From 05e194cb9dfdbcc10be0b47eb9bd0a7c6e7568d2 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sat, 18 Jan 2020 05:50:38 +0000 Subject: [PATCH] * fix not use int8 and uint8 * fix pylint * print debug --- python/mxnet/ndarray/numpy/linalg.py | 267 +++++++++- python/mxnet/numpy/linalg.py | 265 +++++++++- python/mxnet/numpy_dispatch_protocol.py | 4 + python/mxnet/symbol/numpy/linalg.py | 206 +++++++- python/mxnet/test_utils.py | 46 ++ src/operator/c_lapack_api.cc | 13 + src/operator/c_lapack_api.h | 62 +++ src/operator/numpy/linalg/np_eig-inl.h | 274 +++++++++++ src/operator/numpy/linalg/np_eig.cc | 157 ++++++ src/operator/numpy/linalg/np_eig.cu | 42 ++ src/operator/numpy/linalg/np_eigvals-inl.h | 465 ++++++++++++++++++ src/operator/numpy/linalg/np_eigvals.cc | 122 +++++ src/operator/numpy/linalg/np_eigvals.cu | 42 ++ .../unittest/test_numpy_interoperability.py | 28 ++ tests/python/unittest/test_numpy_op.py | 328 ++++++++++++ 15 files changed, 2318 insertions(+), 3 deletions(-) create mode 100644 src/operator/numpy/linalg/np_eig-inl.h create mode 100644 src/operator/numpy/linalg/np_eig.cc create mode 100644 src/operator/numpy/linalg/np_eig.cu create mode 100644 src/operator/numpy/linalg/np_eigvals-inl.h create mode 100644 src/operator/numpy/linalg/np_eigvals.cc create mode 100644 src/operator/numpy/linalg/np_eigvals.cu diff --git a/python/mxnet/ndarray/numpy/linalg.py b/python/mxnet/ndarray/numpy/linalg.py index 51be85851a9b..5cb269ef3d0b 100644 --- a/python/mxnet/ndarray/numpy/linalg.py +++ b/python/mxnet/ndarray/numpy/linalg.py @@ -21,7 +21,8 @@ from . import _op as _mx_nd_np from . import _internal as _npi -__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet', 'solve', 'tensorinv', 'tensorsolve', 'pinv'] +__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet', 'solve', 'tensorinv', 'tensorsolve', 'pinv', + 'eigvals', 'eig', 'eigvalsh', 'eigh'] def pinv(a, rcond=1e-15, hermitian=False): @@ -581,3 +582,267 @@ def tensorsolve(a, b, axes=None): True """ return _npi.tensorsolve(a, b, axes) + + +def eigvals(a): + r""" + Compute the eigenvalues of a general matrix. + + Main difference between `eigvals` and `eig`: the eigenvectors aren't + returned. + + Parameters + ---------- + a : (..., M, M) ndarray + A real-valued matrix whose eigenvalues will be computed. + + Returns + ------- + w : (..., M,) ndarray + The eigenvalues, each repeated according to its multiplicity. + They are not necessarily ordered. + + Raises + ------ + MXNetError + If the eigenvalue computation does not converge. + + See Also + -------- + eig : eigenvalues and right eigenvectors of general arrays + eigh : eigenvalues and eigenvectors of a real symmetric array. + eigvalsh : eigenvalues of a real symmetric. + + Notes + ----- + Broadcasting rules apply, see the `numpy.linalg` documentation for + details. + + This is implemented using the ``_geev`` LAPACK routines which compute + the eigenvalues and eigenvectors of general square arrays. + + This function differs from the original `numpy.linalg.eigvals + `_ in + the following way(s): + - Does not support complex input and output. + + Examples + -------- + Illustration, using the fact that the eigenvalues of a diagonal matrix + are its diagonal elements, that multiplying a matrix on the left + by an orthogonal matrix, `Q`, and on the right by `Q.T` (the transpose + of `Q`), preserves the eigenvalues of the "middle" matrix. 
In other words, + if `Q` is orthogonal, then ``Q * A * Q.T`` has the same eigenvalues as + ``A``: + >>> from numpy import linalg as LA + >>> x = np.random.random() + >>> Q = np.array([[np.cos(x), -np.sin(x)], [np.sin(x), np.cos(x)]]) + >>> LA.norm(Q[0, :]), LA.norm(Q[1, :]), np.dot(Q[0, :],Q[1, :]) + (1.0, 1.0, 0.0) + + Now multiply a diagonal matrix by ``Q`` on one side and by ``Q.T`` on the other: + >>> D = np.diag((-1,1)) + >>> LA.eigvals(D) + array([-1., 1.]) + >>> A = np.dot(Q, D) + >>> A = np.dot(A, Q.T) + >>> LA.eigvals(A) + array([ 1., -1.]) # random + """ + return _npi.eigvals(a) + + +def eigvalsh(a, UPLO='L'): + r""" + Compute the eigenvalues real symmetric matrix. + + Main difference from eigh: the eigenvectors are not computed. + + Parameters + ---------- + a : (..., M, M) ndarray + A real-valued matrix whose eigenvalues are to be computed. + UPLO : {'L', 'U'}, optional + Specifies whether the calculation is done with the lower triangular + part of `a` ('L', default) or the upper triangular part ('U'). + Irrespective of this value only the real parts of the diagonal will + be considered in the computation to preserve the notion of a Hermitian + matrix. It therefore follows that the imaginary part of the diagonal + will always be treated as zero. + + Returns + ------- + w : (..., M,) ndarray + The eigenvalues in ascending order, each repeated according to + its multiplicity. + + Raises + ------ + MXNetError + If the eigenvalue computation does not converge. + + See Also + -------- + eig : eigenvalues and right eigenvectors of general arrays + eigvals : eigenvalues of a non-symmetric array. + eigh : eigenvalues and eigenvectors of a real symmetric array. + + Notes + ----- + Broadcasting rules apply, see the `numpy.linalg` documentation for + details. + + The eigenvalues are computed using LAPACK routines ``_syevd``. + + This function differs from the original `numpy.linalg.eigvalsh + `_ in + the following way(s): + - Does not support complex input and output. + + Examples + -------- + >>> from numpy import linalg as LA + >>> a = np.array([[ 5.4119368 , 8.996273 , -5.086096 ], + [ 0.8866155 , 1.7490431 , -4.6107802 ], + [-0.08034172, 4.4172044 , 1.4528792 ]]) + >>> LA.eigvalsh(a, UPLO='L') + array([-2.87381886, 5.10144682, 6.38623114]) # in ascending order + """ + return _npi.eigvalsh(a, UPLO) + + +def eig(a): + r""" + Compute the eigenvalues and right eigenvectors of a square array. + + Parameters + ---------- + a : (..., M, M) ndarray + Matrices for which the eigenvalues and right eigenvectors will + be computed + + Returns + ------- + w : (..., M) ndarray + The eigenvalues, each repeated according to its multiplicity. + The eigenvalues are not necessarily ordered. + v : (..., M, M) ndarray + The normalized (unit "length") eigenvectors, such that the + column ``v[:,i]`` is the eigenvector corresponding to the + eigenvalue ``w[i]``. + + Raises + ------ + MXNetError + If the eigenvalue computation does not converge. + + See Also + -------- + eigvals : eigenvalues of a non-symmetric array. + eigh : eigenvalues and eigenvectors of a real symmetric array. + eigvalsh : eigenvalues of a real symmetric. + + Notes + ----- + This is implemented using the ``_geev`` LAPACK routines which compute + the eigenvalues and eigenvectors of general square arrays. + + The number `w` is an eigenvalue of `a` if there exists a vector + `v` such that ``dot(a,v) = w * v``. 
Thus, the arrays `a`, `w`, and + `v` satisfy the equations ``dot(a[:,:], v[:,i]) = w[i] * v[:,i]`` + for :math:`i \\in \\{0,...,M-1\\}`. + + The array `v` of eigenvectors may not be of maximum rank, that is, some + of the columns may be linearly dependent, although round-off error may + obscure that fact. If the eigenvalues are all different, then theoretically + the eigenvectors are linearly independent. + + This function differs from the original `numpy.linalg.eig + `_ in + the following way(s): + - Does not support complex input and output. + + Examples + -------- + >>> from numpy import linalg as LA + >>> a = np.array([[-1.9147992 , 6.054115 , 18.046988 ], + [ 0.77563655, -4.860152 , 2.1012988 ], + [ 2.6083658 , 2.3705218 , 0.3192524 ]]) + >>> w, v = LA.eig(a) + >>> w + array([ 6.9683027, -7.768063 , -5.655937 ]) + >>> v + array([[ 0.90617794, 0.9543622 , 0.2492316 ], + [ 0.13086087, -0.04077047, -0.9325615 ], + [ 0.4021404 , -0.29585576, 0.26117516]]) + """ + w, v = _npi.eig(a) + return (w, v) + + +def eigh(a, UPLO='L'): + r""" + Return the eigenvalues and eigenvectors real symmetric matrix. + + Returns two objects, a 1-D array containing the eigenvalues of `a`, and + a 2-D square array or matrix (depending on the input type) of the + corresponding eigenvectors (in columns). + + Parameters + ---------- + a : (..., M, M) ndarray + real symmetric matrices whose eigenvalues and eigenvectors are to be computed. + UPLO : {'L', 'U'}, optional + Specifies whether the calculation is done with the lower triangular + part of `a` ('L', default) or the upper triangular part ('U'). + Irrespective of this value only the real parts of the diagonal will + be considered in the computation to preserve the notion of a Hermitian + matrix. It therefore follows that the imaginary part of the diagonal + will always be treated as zero. + + Returns + ------- + w : (..., M) ndarray + The eigenvalues in ascending order, each repeated according to + its multiplicity. + v : {(..., M, M) ndarray, (..., M, M) matrix} + The column ``v[:, i]`` is the normalized eigenvector corresponding + to the eigenvalue ``w[i]``. Will return a matrix object if `a` is + a matrix object. + + Raises + ------ + MXNetError + If the eigenvalue computation does not converge. + + See Also + -------- + eig : eigenvalues and right eigenvectors of general arrays + eigvals : eigenvalues of a non-symmetric array. + eigvalsh : eigenvalues of a real symmetric. + + Notes + ----- + The eigenvalues/eigenvectors are computed using LAPACK routines ``_syevd``. + + This function differs from the original `numpy.linalg.eigh + `_ in + the following way(s): + - Does not support complex input and output. + + Examples + -------- + >>> from numpy import linalg as LA + >>> a = np.array([[ 6.8189726 , -3.926585 , 4.3990498 ], + [-0.59656644, -1.9166266 , 9.54532 ], + [ 2.1093285 , 0.19688708, -1.1634291 ]]) + >>> w, v = LA.eigh(a, UPLO='L') + >>> w + array([-2.175445 , -1.4581827, 7.3725457]) + >>> v + array([[ 0.1805163 , -0.16569263, 0.9695154 ], + [ 0.8242942 , 0.56326365, -0.05721384], + [-0.53661287, 0.80949366, 0.23825769]]) + """ + w, v = _npi.eigh(a, UPLO) + return (w, v) diff --git a/python/mxnet/numpy/linalg.py b/python/mxnet/numpy/linalg.py index 82b46d21d64f..456a4b897bbd 100644 --- a/python/mxnet/numpy/linalg.py +++ b/python/mxnet/numpy/linalg.py @@ -22,7 +22,8 @@ from .fallback_linalg import * # pylint: disable=wildcard-import,unused-wildcard-import from . 
import fallback_linalg -__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet', 'solve', 'tensorinv', 'tensorsolve', 'pinv'] +__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet', 'solve', 'tensorinv', 'tensorsolve', 'pinv', + 'eigvals', 'eig', 'eigvalsh', 'eigh'] __all__ += fallback_linalg.__all__ @@ -598,3 +599,265 @@ def tensorsolve(a, b, axes=None): True """ return _mx_nd_np.linalg.tensorsolve(a, b, axes) + + +def eigvals(a): + r""" + Compute the eigenvalues of a general matrix. + + Main difference between `eigvals` and `eig`: the eigenvectors aren't + returned. + + Parameters + ---------- + a : (..., M, M) ndarray + A real-valued matrix whose eigenvalues will be computed. + + Returns + ------- + w : (..., M,) ndarray + The eigenvalues, each repeated according to its multiplicity. + They are not necessarily ordered. + + Raises + ------ + MXNetError + If the eigenvalue computation does not converge. + + See Also + -------- + eig : eigenvalues and right eigenvectors of general arrays + eigh : eigenvalues and eigenvectors of a real symmetric array. + eigvalsh : eigenvalues of a real symmetric. + + Notes + ----- + Broadcasting rules apply, see the `numpy.linalg` documentation for + details. + + This is implemented using the ``_geev`` LAPACK routines which compute + the eigenvalues and eigenvectors of general square arrays. + + This function differs from the original `numpy.linalg.eigvals + `_ in + the following way(s): + - Does not support complex input and output. + + Examples + -------- + Illustration, using the fact that the eigenvalues of a diagonal matrix + are its diagonal elements, that multiplying a matrix on the left + by an orthogonal matrix, `Q`, and on the right by `Q.T` (the transpose + of `Q`), preserves the eigenvalues of the "middle" matrix. In other words, + if `Q` is orthogonal, then ``Q * A * Q.T`` has the same eigenvalues as + ``A``: + >>> from numpy import linalg as LA + >>> x = np.random.random() + >>> Q = np.array([[np.cos(x), -np.sin(x)], [np.sin(x), np.cos(x)]]) + >>> LA.norm(Q[0, :]), LA.norm(Q[1, :]), np.dot(Q[0, :],Q[1, :]) + (1.0, 1.0, 0.0) + + Now multiply a diagonal matrix by ``Q`` on one side and by ``Q.T`` on the other: + >>> D = np.diag((-1,1)) + >>> LA.eigvals(D) + array([-1., 1.]) + >>> A = np.dot(Q, D) + >>> A = np.dot(A, Q.T) + >>> LA.eigvals(A) + array([ 1., -1.]) # random + """ + return _mx_nd_np.linalg.eigvals(a) + + +def eigvalsh(a, UPLO='L'): + r""" + Compute the eigenvalues real symmetric matrix. + + Main difference from eigh: the eigenvectors are not computed. + + Parameters + ---------- + a : (..., M, M) ndarray + A real-valued matrix whose eigenvalues are to be computed. + UPLO : {'L', 'U'}, optional + Specifies whether the calculation is done with the lower triangular + part of `a` ('L', default) or the upper triangular part ('U'). + Irrespective of this value only the real parts of the diagonal will + be considered in the computation to preserve the notion of a Hermitian + matrix. It therefore follows that the imaginary part of the diagonal + will always be treated as zero. + + Returns + ------- + w : (..., M,) ndarray + The eigenvalues in ascending order, each repeated according to + its multiplicity. + + Raises + ------ + MXNetError + If the eigenvalue computation does not converge. + + See Also + -------- + eig : eigenvalues and right eigenvectors of general arrays + eigvals : eigenvalues of a non-symmetric array. + eigh : eigenvalues and eigenvectors of a real symmetric array. 
+ + Notes + ----- + Broadcasting rules apply, see the `numpy.linalg` documentation for + details. + + The eigenvalues are computed using LAPACK routines ``_syevd``. + + This function differs from the original `numpy.linalg.eigvalsh + `_ in + the following way(s): + - Does not support complex input and output. + + Examples + -------- + >>> from numpy import linalg as LA + >>> a = np.array([[ 5.4119368 , 8.996273 , -5.086096 ], + [ 0.8866155 , 1.7490431 , -4.6107802 ], + [-0.08034172, 4.4172044 , 1.4528792 ]]) + >>> LA.eigvalsh(a, UPLO='L') + array([-2.87381886, 5.10144682, 6.38623114]) # in ascending order + """ + return _mx_nd_np.linalg.eigvalsh(a, UPLO) + + +def eig(a): + r""" + Compute the eigenvalues and right eigenvectors of a square array. + + Parameters + ---------- + a : (..., M, M) ndarray + Matrices for which the eigenvalues and right eigenvectors will + be computed + + Returns + ------- + w : (..., M) ndarray + The eigenvalues, each repeated according to its multiplicity. + The eigenvalues are not necessarily ordered. + v : (..., M, M) ndarray + The normalized (unit "length") eigenvectors, such that the + column ``v[:,i]`` is the eigenvector corresponding to the + eigenvalue ``w[i]``. + + Raises + ------ + MXNetError + If the eigenvalue computation does not converge. + + See Also + -------- + eigvals : eigenvalues of a non-symmetric array. + eigh : eigenvalues and eigenvectors of a real symmetric array. + eigvalsh : eigenvalues of a real symmetric. + + Notes + ----- + This is implemented using the ``_geev`` LAPACK routines which compute + the eigenvalues and eigenvectors of general square arrays. + + The number `w` is an eigenvalue of `a` if there exists a vector + `v` such that ``dot(a,v) = w * v``. Thus, the arrays `a`, `w`, and + `v` satisfy the equations ``dot(a[:,:], v[:,i]) = w[i] * v[:,i]`` + for :math:`i \\in \\{0,...,M-1\\}`. + + The array `v` of eigenvectors may not be of maximum rank, that is, some + of the columns may be linearly dependent, although round-off error may + obscure that fact. If the eigenvalues are all different, then theoretically + the eigenvectors are linearly independent. + + This function differs from the original `numpy.linalg.eig + `_ in + the following way(s): + - Does not support complex input and output. + + Examples + -------- + >>> from numpy import linalg as LA + >>> a = np.array([[-1.9147992 , 6.054115 , 18.046988 ], + [ 0.77563655, -4.860152 , 2.1012988 ], + [ 2.6083658 , 2.3705218 , 0.3192524 ]]) + >>> w, v = LA.eig(a) + >>> w + array([ 6.9683027, -7.768063 , -5.655937 ]) + >>> v + array([[ 0.90617794, 0.9543622 , 0.2492316 ], + [ 0.13086087, -0.04077047, -0.9325615 ], + [ 0.4021404 , -0.29585576, 0.26117516]]) + """ + return _mx_nd_np.linalg.eig(a) + + +def eigh(a, UPLO='L'): + r""" + Return the eigenvalues and eigenvectors real symmetric matrix. + + Returns two objects, a 1-D array containing the eigenvalues of `a`, and + a 2-D square array or matrix (depending on the input type) of the + corresponding eigenvectors (in columns). + + Parameters + ---------- + a : (..., M, M) ndarray + real symmetric matrices whose eigenvalues and eigenvectors are to be computed. + UPLO : {'L', 'U'}, optional + Specifies whether the calculation is done with the lower triangular + part of `a` ('L', default) or the upper triangular part ('U'). + Irrespective of this value only the real parts of the diagonal will + be considered in the computation to preserve the notion of a Hermitian + matrix. 
It therefore follows that the imaginary part of the diagonal + will always be treated as zero. + + Returns + ------- + w : (..., M) ndarray + The eigenvalues in ascending order, each repeated according to + its multiplicity. + v : {(..., M, M) ndarray, (..., M, M) matrix} + The column ``v[:, i]`` is the normalized eigenvector corresponding + to the eigenvalue ``w[i]``. Will return a matrix object if `a` is + a matrix object. + + Raises + ------ + MXNetError + If the eigenvalue computation does not converge. + + See Also + -------- + eig : eigenvalues and right eigenvectors of general arrays + eigvals : eigenvalues of a non-symmetric array. + eigvalsh : eigenvalues of a real symmetric. + + Notes + ----- + The eigenvalues/eigenvectors are computed using LAPACK routines ``_syevd``. + + This function differs from the original `numpy.linalg.eigh + `_ in + the following way(s): + - Does not support complex input and output. + + Examples + -------- + >>> from numpy import linalg as LA + >>> a = np.array([[ 6.8189726 , -3.926585 , 4.3990498 ], + [-0.59656644, -1.9166266 , 9.54532 ], + [ 2.1093285 , 0.19688708, -1.1634291 ]]) + >>> w, v = LA.eigh(a, UPLO='L') + >>> w + array([-2.175445 , -1.4581827, 7.3725457]) + >>> v + array([[ 0.1805163 , -0.16569263, 0.9695154 ], + [ 0.8242942 , 0.56326365, -0.05721384], + [-0.53661287, 0.80949366, 0.23825769]]) + """ + return _mx_nd_np.linalg.eigh(a, UPLO) diff --git a/python/mxnet/numpy_dispatch_protocol.py b/python/mxnet/numpy_dispatch_protocol.py index 56944facac81..ac6c2a2c448a 100644 --- a/python/mxnet/numpy_dispatch_protocol.py +++ b/python/mxnet/numpy_dispatch_protocol.py @@ -149,6 +149,10 @@ def _run_with_array_ufunc_proto(*args, **kwargs): 'linalg.tensorinv', 'linalg.tensorsolve', 'linalg.pinv', + 'linalg.eigvals', + 'linalg.eig', + 'linalg.eigvalsh', + 'linalg.eigh', 'shape', 'trace', 'tril', diff --git a/python/mxnet/symbol/numpy/linalg.py b/python/mxnet/symbol/numpy/linalg.py index 979742001aa8..2aa68852fc6b 100644 --- a/python/mxnet/symbol/numpy/linalg.py +++ b/python/mxnet/symbol/numpy/linalg.py @@ -22,7 +22,9 @@ from . import _op as _mx_sym_np from . import _internal as _npi -__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet', 'solve', 'tensorinv', 'tensorsolve', 'pinv'] +__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet', 'solve', 'tensorinv', 'tensorsolve', 'pinv', + 'eigvals', 'eig', 'eigvalsh', 'eigh'] + def pinv(a, rcond=1e-15, hermitian=False): r""" @@ -567,3 +569,205 @@ def tensorsolve(a, b, axes=None): True """ return _npi.tensorsolve(a, b, axes) + + +def eigvals(a): + r""" + Compute the eigenvalues of a general matrix. + + Main difference between `eigvals` and `eig`: the eigenvectors aren't + returned. + + Parameters + ---------- + a : (..., M, M) ndarray + A real-valued matrix whose eigenvalues will be computed. + + Returns + ------- + w : (..., M,) ndarray + The eigenvalues, each repeated according to its multiplicity. + They are not necessarily ordered. + + Raises + ------ + MXNetError + If the eigenvalue computation does not converge. + + See Also + -------- + eig : eigenvalues and right eigenvectors of general arrays + eigh : eigenvalues and eigenvectors of a real symmetric array. + eigvalsh : eigenvalues of a real symmetric. + + Notes + ----- + Broadcasting rules apply, see the `numpy.linalg` documentation for + details. + + This is implemented using the ``_geev`` LAPACK routines which compute + the eigenvalues and eigenvectors of general square arrays. 
+ + This function differs from the original `numpy.linalg.eigvals + `_ in + the following way(s): + - Does not support complex input and output. + """ + return _npi.eigvals(a) + + +def eigvalsh(a, UPLO='L'): + r""" + Compute the eigenvalues real symmetric matrix. + + Main difference from eigh: the eigenvectors are not computed. + + Parameters + ---------- + a : (..., M, M) ndarray + A real-valued matrix whose eigenvalues are to be computed. + UPLO : {'L', 'U'}, optional + Specifies whether the calculation is done with the lower triangular + part of `a` ('L', default) or the upper triangular part ('U'). + Irrespective of this value only the real parts of the diagonal will + be considered in the computation to preserve the notion of a Hermitian + matrix. It therefore follows that the imaginary part of the diagonal + will always be treated as zero. + + Returns + ------- + w : (..., M,) ndarray + The eigenvalues in ascending order, each repeated according to + its multiplicity. + + Raises + ------ + MXNetError + If the eigenvalue computation does not converge. + + See Also + -------- + eig : eigenvalues and right eigenvectors of general arrays + eigvals : eigenvalues of a non-symmetric array. + eigh : eigenvalues and eigenvectors of a real symmetric array. + + Notes + ----- + Broadcasting rules apply, see the `numpy.linalg` documentation for + details. + + The eigenvalues are computed using LAPACK routines ``_syevd``. + + This function differs from the original `numpy.linalg.eigvalsh + `_ in + the following way(s): + - Does not support complex input and output. + """ + return _npi.eigvalsh(a, UPLO) + + +def eig(a): + r""" + Compute the eigenvalues and right eigenvectors of a square array. + + Parameters + ---------- + a : (..., M, M) ndarray + Matrices for which the eigenvalues and right eigenvectors will + be computed + + Returns + ------- + w : (..., M) ndarray + The eigenvalues, each repeated according to its multiplicity. + The eigenvalues are not necessarily ordered. + v : (..., M, M) ndarray + The normalized (unit "length") eigenvectors, such that the + column ``v[:,i]`` is the eigenvector corresponding to the + eigenvalue ``w[i]``. + + Raises + ------ + MXNetError + If the eigenvalue computation does not converge. + + See Also + -------- + eigvals : eigenvalues of a non-symmetric array. + eigh : eigenvalues and eigenvectors of a real symmetric array. + eigvalsh : eigenvalues of a real symmetric. + + Notes + ----- + This is implemented using the ``_geev`` LAPACK routines which compute + the eigenvalues and eigenvectors of general square arrays. + + The number `w` is an eigenvalue of `a` if there exists a vector + `v` such that ``dot(a,v) = w * v``. Thus, the arrays `a`, `w`, and + `v` satisfy the equations ``dot(a[:,:], v[:,i]) = w[i] * v[:,i]`` + for :math:`i \\in \\{0,...,M-1\\}`. + + The array `v` of eigenvectors may not be of maximum rank, that is, some + of the columns may be linearly dependent, although round-off error may + obscure that fact. If the eigenvalues are all different, then theoretically + the eigenvectors are linearly independent. + + This function differs from the original `numpy.linalg.eig + `_ in + the following way(s): + - Does not support complex input and output. + """ + return _npi.eig(a) + + +def eigh(a, UPLO='L'): + r""" + Return the eigenvalues and eigenvectors real symmetric matrix. 
+ + Returns two objects, a 1-D array containing the eigenvalues of `a`, and + a 2-D square array or matrix (depending on the input type) of the + corresponding eigenvectors (in columns). + + Parameters + ---------- + a : (..., M, M) ndarray + real symmetric matrices whose eigenvalues and eigenvectors are to be computed. + UPLO : {'L', 'U'}, optional + Specifies whether the calculation is done with the lower triangular + part of `a` ('L', default) or the upper triangular part ('U'). + Irrespective of this value only the real parts of the diagonal will + be considered in the computation to preserve the notion of a Hermitian + matrix. It therefore follows that the imaginary part of the diagonal + will always be treated as zero. + + Returns + ------- + w : (..., M) ndarray + The eigenvalues in ascending order, each repeated according to + its multiplicity. + v : {(..., M, M) ndarray, (..., M, M) matrix} + The column ``v[:, i]`` is the normalized eigenvector corresponding + to the eigenvalue ``w[i]``. Will return a matrix object if `a` is + a matrix object. + + Raises + ------ + MXNetError + If the eigenvalue computation does not converge. + + See Also + -------- + eig : eigenvalues and right eigenvectors of general arrays + eigvals : eigenvalues of a non-symmetric array. + eigvalsh : eigenvalues of a real symmetric. + + Notes + ----- + The eigenvalues/eigenvectors are computed using LAPACK routines ``_syevd``. + + This function differs from the original `numpy.linalg.eigh + `_ in + the following way(s): + - Does not support complex input and output. + """ + return _npi.eigh(a, UPLO) diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index c60e5bc22201..6fc38b84a1d2 100755 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -2483,3 +2483,49 @@ def _init_weight(self, name, arr): if numpy_func is not None: numpy_out = numpy_func(*[ele.asnumpy() for ele in data_l]) assert_almost_equal(saved_out_np, numpy_out, rtol=rtol, atol=atol) + + +def new_matrix_with_real_eigvals_2d(n): + """Generate a well-conditioned matrix with small real eigenvalues.""" + shape = (n, n) + q = np.ones(shape) + while 1: + D = np.diag(np.random.uniform(-1.0, 1.0, shape[-1])) + I = np.eye(shape[-1]).reshape(shape) + v = np.random.uniform(-1., 1., shape[-1]).reshape(shape[:-1] + (1,)) + v = v / np.linalg.norm(v, axis=-2, keepdims=True) + v_T = np.swapaxes(v, -1, -2) + U = I - 2 * np.matmul(v, v_T) + q = np.matmul(U, D) + if (np.linalg.cond(q, 2) < 3): + break + D = np.diag(np.random.uniform(-10.0, 10.0, n)) + q_inv = np.linalg.inv(q) + return np.matmul(np.matmul(q_inv, D), q) + + +def new_matrix_with_real_eigvals_nd(shape): + """Generate well-conditioned matrices with small real eigenvalues.""" + n = int(np.prod(shape[:-2])) if len(shape) > 2 else 1 + return np.array([new_matrix_with_real_eigvals_2d(shape[-1]) for i in range(n)]).reshape(shape) + + +def new_orthonormal_matrix_2d(n): + """Generate a orthonormal matrix.""" + x = np.random.randn(n, n) + x_trans = x.T + sym_mat = np.matmul(x_trans, x) + return np.linalg.qr(sym_mat)[0] + + +def new_sym_matrix_with_real_eigvals_2d(n): + """Generate a sym matrix with real eigenvalues.""" + q = new_orthonormal_matrix_2d(n) + D = np.diag(np.random.uniform(-10.0, 10.0, n)) + return np.matmul(np.matmul(q.T, D), q) + + +def new_sym_matrix_with_real_eigvals_nd(shape): + """Generate sym matrices with real eigenvalues.""" + n = int(np.prod(shape[:-2])) if len(shape) > 2 else 1 + return np.array([new_sym_matrix_with_real_eigvals_2d(shape[-1]) for i in 
range(n)]).reshape(shape) diff --git a/src/operator/c_lapack_api.cc b/src/operator/c_lapack_api.cc index 442789e95d13..57e83b1ddb2d 100644 --- a/src/operator/c_lapack_api.cc +++ b/src/operator/c_lapack_api.cc @@ -78,6 +78,16 @@ return 1; \ } + #define MXNET_LAPACK_CWRAPPER8(func, dtype) \ + int MXNET_LAPACK_##func(int matrix_layout, char jobvl, char jobvr, \ + int n, dtype *a, int lda, \ + dtype *wr, dtype *wi, \ + dtype *vl, int ldvl, dtype *vr, int ldvr, \ + dtype *work, int lwork) { \ + LOG(FATAL) << "MXNet build without lapack. Function " << #func << " is not available."; \ + return 1; \ + } + #define MXNET_LAPACK_CWRAPPER9(func, dtype) \ int MXNET_LAPACK_##func(int matrix_layout, int m, int n, \ dtype *a, int lda, dtype *s, \ @@ -121,6 +131,9 @@ MXNET_LAPACK_CWRAPPER7(sgesv, float) MXNET_LAPACK_CWRAPPER7(dgesv, double) + MXNET_LAPACK_CWRAPPER8(sgeev, float) + MXNET_LAPACK_CWRAPPER8(dgeev, double) + MXNET_LAPACK_CWRAPPER9(sgesdd, float) MXNET_LAPACK_CWRAPPER9(dgesdd, double) diff --git a/src/operator/c_lapack_api.h b/src/operator/c_lapack_api.h index 8029a71f61d3..8b07265ba299 100644 --- a/src/operator/c_lapack_api.h +++ b/src/operator/c_lapack_api.h @@ -180,6 +180,23 @@ extern "C" { MXNET_LAPACK_FSIG_GESDD(sgesdd, float) MXNET_LAPACK_FSIG_GESDD(dgesdd, double) + + #ifdef __ANDROID__ + #define MXNET_LAPACK_FSIG_GEEV(func, dtype) \ + int func##_(char *jobvl, char *jobvr, int *n, dtype *a, int *lda, \ + dtype *wr, dtype *wi, \ + dtype *vl, int *ldvl, dtype *vr, int *ldvr, \ + dtype *work, int *lwork, int *info); + #else + #define MXNET_LAPACK_FSIG_GEEV(func, dtype) \ + void func##_(char *jobvl, char *jobvr, int *n, dtype *a, int *lda, \ + dtype *wr, dtype *wi, \ + dtype *vl, int *ldvl, dtype *vr, int *ldvr, \ + dtype *work, int *lwork, int *info); + #endif + + MXNET_LAPACK_FSIG_GEEV(sgeev, float) + MXNET_LAPACK_FSIG_GEEV(dgeev, double) } #endif // MSHADOW_USE_MKL == 0 @@ -329,6 +346,22 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) { MXNET_LAPACK_CWRAP_GETRI(s, float) MXNET_LAPACK_CWRAP_GETRI(d, double) + #define MXNET_LAPACK_CWRAP_GEEV(prefix, dtype) \ + inline int MXNET_LAPACK_##prefix##geev(int matrix_layout, char jobvl, char jobvr, \ + int n, dtype *a, int lda, \ + dtype *wr, dtype *wi, \ + dtype *vl, int ldvl, dtype *vr, int ldvr, \ + dtype *work, int lwork) { \ + if (lwork != -1) { \ + return LAPACKE_##prefix##geev(matrix_layout, jobvl, jobvr, \ + n, a, lda, wr, wi, vl, ldvl, vr, ldvr); \ + } \ + *work = 0; \ + return 0; \ + } + MXNET_LAPACK_CWRAP_GEEV(s, float) + MXNET_LAPACK_CWRAP_GEEV(d, double) + #elif MXNET_USE_LAPACK #define MXNET_LAPACK_ROW_MAJOR 101 @@ -475,6 +508,25 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) { MXNET_LAPACK_CWRAP_GESDD(sgesdd, float) MXNET_LAPACK_CWRAP_GESDD(dgesdd, double) + #define MXNET_LAPACK_CWRAP_GEEV(prefix, dtype) \ + inline int MXNET_LAPACK_##prefix##geev(int matrix_layout, char jobvl, char jobvr, \ + int n, dtype *a, int lda, \ + dtype *wr, dtype *wi, \ + dtype *vl, int ldvl, dtype *vr, int ldvr, \ + dtype *work, int lwork) { \ + if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \ + CHECK(false) << "MXNET_LAPACK_" << #prefix << "geev implemented for col-major layout only"; \ + return 1; \ + } else { \ + int info(0); \ + prefix##geev_(&jobvl, &jobvr, \ + &n, a, &lda, wr, wi, vl, &ldvl, vr, &ldvr, work, &lwork, &info); \ + return info; \ + } \ + } + MXNET_LAPACK_CWRAP_GEEV(s, float) + MXNET_LAPACK_CWRAP_GEEV(d, double) + #define MXNET_LAPACK // Note: Both MXNET_LAPACK_*getrf, MXNET_LAPACK_*getri can 
only be called with col-major format @@ -560,6 +612,13 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) { int MXNET_LAPACK_##func(int matrix_order, int n, int nrhs, dtype *a, \ int lda, int *ipiv, dtype *b, int ldb); \ + #define MXNET_LAPACK_CWRAPPER8(func, dtype) \ + int MXNET_LAPACK_##func(int matrix_layout, char jobvl, char jobvr, \ + int n, dtype *a, int lda, \ + dtype *wr, dtype *wi, \ + dtype *vl, int ldvl, dtype *vr, int ldvr, \ + dtype *work, int lwork); \ + #define MXNET_LAPACK_CWRAPPER9(func, dtype) \ int MXNET_LAPACK_##func(int matrix_layout, int m, int n, \ dtype *a, int lda, dtype *s, \ @@ -597,6 +656,9 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) { MXNET_LAPACK_CWRAPPER7(sgesv, float) MXNET_LAPACK_CWRAPPER7(dgesv, double) + MXNET_LAPACK_CWRAPPER8(sgeev, float) + MXNET_LAPACK_CWRAPPER8(dgeev, double) + MXNET_LAPACK_CWRAPPER9(sgesdd, float) MXNET_LAPACK_CWRAPPER9(dgesdd, double) diff --git a/src/operator/numpy/linalg/np_eig-inl.h b/src/operator/numpy/linalg/np_eig-inl.h new file mode 100644 index 000000000000..65ca22779369 --- /dev/null +++ b/src/operator/numpy/linalg/np_eig-inl.h @@ -0,0 +1,274 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_eig-inl.h + * \brief Placeholder for eig + */ +#ifndef MXNET_OPERATOR_NUMPY_LINALG_NP_EIG_INL_H_ +#define MXNET_OPERATOR_NUMPY_LINALG_NP_EIG_INL_H_ + +#include +#include "./np_eigvals-inl.h" + +namespace mxnet { +namespace op { + +using namespace mshadow; + +template +struct eigvec_assign_helper { + template + MSHADOW_XINLINE static void Map(int i, const DType *in_data, DType *out_data, + const int nrow, const int ld, const int step) { + int idx = i / step, row = (i % step) / ld, col = (i % step) % ld; + KERNEL_ASSIGN(out_data[idx * step + row + col * ld], req, in_data[i]); + } +}; + +// Calculates workspace size of eig forward op. +// The dimension of the array WORK in LAPACKE_#GEEV should >= max(1,3*N), and +// if JOBVL = 'V' or JOBVR = 'V', LWORK >= 4*N. +// For good performance, LWORK must generally be larger. +template +size_t EigForwardWorkspaceSize(const TBlob& a, + const TBlob& w, + const TBlob& v, + const std::vector& req) { + if (kNullOp == req[0] && kNullOp == req[1]) { return 0U; } + + // Zero-size input, no need to launch kernel + if (0U == a.Size()) { return 0U; } + + MSHADOW_SGL_DBL_TYPE_SWITCH(w.type_flag_, DType, { + size_t work_space_size = 0; + size_t n = a.size(a.ndim() - 1); + work_space_size += a.Size(); // For matrix. + work_space_size += 2 * w.Size(); // For eigenvalues' real and image component. + work_space_size += n * n; // For left eigenvectors temp memory + work_space_size += v.Size(); // For right eigenvectors real and image component. 
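+    // The forward pass calls geev with jobvr = 'V', so LAPACK needs lwork >= 4*N here.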
+ work_space_size += 4 * n; // For workspace size in LAPACKE_#GEEV. + work_space_size *= sizeof(DType); + return work_space_size; + }); + LOG(FATAL) << "InternalError: cannot reach here"; + return 0U; +} + +template +void EigOpForwardImpl(const TBlob& a, + const TBlob& w, + const TBlob& v, + const std::vector& req, + std::vector *workspace, + mshadow::Stream *s) { + if (kNullOp == req[0] && kNullOp == req[1]) { return; } + const mxnet::TShape& a_shape = a.shape_; + const mxnet::TShape& w_shape = w.shape_; + const mxnet::TShape& v_shape = v.shape_; + const int a_ndim = a_shape.ndim(); + + // Zero-size output, no need to launch kernel + if (0U == a.Size()) { return; } + + MSHADOW_SGL_DBL_TYPE_SWITCH(w.type_flag_, DType, { + const int N = a_shape[a_ndim - 1]; + DType *a_ptr = + reinterpret_cast(workspace->data()); + DType *wr_ptr = + reinterpret_cast(workspace->data() + a.Size() * sizeof(DType)); + DType *wi_ptr = + reinterpret_cast(workspace->data() + (w.Size() + a.Size()) * sizeof(DType)); + DType *vl_ptr = + reinterpret_cast(workspace->data() + (2 * w.Size() + a.Size()) * sizeof(DType)); + DType *vr_ptr = + reinterpret_cast( + workspace->data() + (2 * w.Size() + N * N + a.Size()) * sizeof(DType)); + DType *work_ptr = + reinterpret_cast( + workspace->data() + (2 * w.Size() + v.Size() + N * N + a.Size()) * sizeof(DType)); + MSHADOW_TYPE_SWITCH(a.type_flag_, AType, { + // Cast type and transpose. + mxnet_op::Kernel::Launch( + s, a_shape.Size(), a.dptr(), a_ptr, N, N, N * N); + }); + char jobvl = 'N', jobvr = 'V'; + mxnet::TBlob a_trans_data(a_ptr, a_shape, a.dev_mask(), a.dev_id()); + mxnet::TBlob wr_data(wr_ptr, w_shape, w.dev_mask(), w.dev_id()); + mxnet::TBlob wi_data(wi_ptr, w_shape, w.dev_mask(), w.dev_id()); + mxnet::TBlob vl_data(vl_ptr, Shape3(1, N, N), v.dev_mask(), v.dev_id()); + mxnet::TBlob vr_data(vr_ptr, v_shape, v.dev_mask(), v.dev_id()); + mxnet::TBlob work_data(work_ptr, Shape1(4 * N), a.dev_mask(), a.dev_id()); + eig_eigvals::op(jobvl, jobvr, + a_trans_data.FlatToKD(s), + wr_data.FlatToKD(s), + wi_data.FlatToKD(s), + vl_data.get(s), + vr_data.FlatToKD(s), + work_data.get(s)); + for (size_t i = 0; i < wi_data.Size(); ++i) { + CHECK_LE(fabs(wi_ptr[i]), 1e-15) + << "Complex eigvals is unsupported in linalg temporary."; + } + MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { + mxnet_op::Kernel, xpu>::Launch( + s, w.Size(), wr_ptr, w.dptr()); + }); + MXNET_ASSIGN_REQ_SWITCH(req[1], req_type, { + mxnet_op::Kernel, xpu>::Launch( + s, v.Size(), vr_ptr, v.dptr(), N, N, N * N); + }); + }); +} + +template +void GpuCallbackCpuImpl(const TBlob& a, + const TBlob& w, + const TBlob& v, + AType* a_cp_ptr, + WType* w_cp_ptr, + WType* v_cp_ptr, + std::vector* workspace, + const OpContext& ctx, + const std::vector& req) { +#if MXNET_USE_CUDA + mshadow::Stream *s = ctx.get_stream(); + cudaStream_t stream = Stream::GetStream(s); + CUDA_CALL(cudaMemcpyAsync(a_cp_ptr, a.dptr(), sizeof(AType) * a.Size(), + cudaMemcpyDeviceToHost, stream)); + CUDA_CALL(cudaMemcpyAsync(w_cp_ptr, w.dptr(), sizeof(WType) * w.Size(), + cudaMemcpyDeviceToHost, stream)); + CUDA_CALL(cudaMemcpyAsync(v_cp_ptr, v.dptr(), sizeof(WType) * v.Size(), + cudaMemcpyDeviceToHost, stream)); + CUDA_CALL(cudaStreamSynchronize(stream)); + mxnet::TBlob a_data(a_cp_ptr, a.shape_, cpu::kDevMask); + mxnet::TBlob w_data(w_cp_ptr, w.shape_, cpu::kDevMask); + mxnet::TBlob v_data(v_cp_ptr, v.shape_, cpu::kDevMask); + // Op forward implement on cpu. + EigOpForwardImpl(a_data, w_data, v_data, req, workspace, ctx.get_stream()); + // Copy back to gpu. 
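+  // Only the outputs w and v need to go back to the device; the host copy of a is temporary scratch.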
+ CUDA_CALL(cudaMemcpyAsync(w.dptr(), w_cp_ptr, sizeof(WType) * w.Size(), + cudaMemcpyHostToDevice, stream)); + CUDA_CALL(cudaMemcpyAsync(v.dptr(), v_cp_ptr, sizeof(WType) * v.Size(), + cudaMemcpyHostToDevice, stream)); + CUDA_CALL(cudaStreamSynchronize(stream)); +#else + LOG(FATAL) << "Please build with USE_CUDA=1 to enable GPU"; +#endif // MXNET_USE_CUDA +} + +template +void EigOpForward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 2U); + CHECK_EQ(req.size(), 2U); + const TBlob& a = inputs[0]; + const TBlob& w = outputs[0]; + const TBlob& v = outputs[1]; + + // Calculate workspace size. + size_t workspace_size = EigForwardWorkspaceSize(a, w, v, req); + std::vector workspace(workspace_size, 0); + + MSHADOW_SGL_DBL_TYPE_SWITCH(w.type_flag_, WType, { + MSHADOW_TYPE_SWITCH(a.type_flag_, AType, { + if (xpu::kDevCPU) { + // Op forward implement. + EigOpForwardImpl(a, w, v, req, &workspace, ctx.get_stream()); + } else { + std::vector a_vec(a.Size(), 0); + std::vector w_vec(w.Size(), 0); + std::vector v_vec(v.Size(), 0); + AType* a_cp_ptr = a_vec.data(); + WType* w_cp_ptr = w_vec.data(); + WType* v_cp_ptr = v_vec.data(); + GpuCallbackCpuImpl(a, w, v, a_cp_ptr, w_cp_ptr, v_cp_ptr, &workspace, ctx, req); + } + }); + }); +} + +struct EighParam : public dmlc::Parameter { + char UPLO; + DMLC_DECLARE_PARAMETER(EighParam) { + DMLC_DECLARE_FIELD(UPLO) + .set_default('L') + .describe("Specifies whether the calculation is done with the lower or upper triangular part."); + } +}; + +template +void EighOpForward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 2U); + CHECK_EQ(req.size(), 2U); + const TBlob& a = inputs[0]; + const TBlob& w = outputs[0]; + const TBlob& v = outputs[1]; + const char UPLO = nnvm::get(attrs.parsed).UPLO; + Stream *s = ctx.get_stream(); + + if (kNullOp == req[0] && kNullOp == req[0]) { return; } + // Zero-size output, no need to launch kernel + if (0U == a.Size()) { return; } + + // Calculate workspace size. 
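+  // The workspace holds a casted copy of A, the eigenvalue vector W, and the scratch buffer
+  // reported by the syevd workspace query, all stored as DType elements.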
+ size_t workspace_size = EighEigvalshForwardWorkspaceSize(a, w, req, ctx); + Tensor workspace = + ctx.requested[0].get_space_typed(Shape1(workspace_size), s); + + EighEigvalshOpForwardImpl(a, w, UPLO, attrs, ctx, req, workspace); + + MSHADOW_SGL_DBL_TYPE_SWITCH(w.type_flag_, DType, { + DType *a_ptr = reinterpret_cast(workspace.dptr_); + DType *w_ptr = reinterpret_cast(workspace.dptr_ + a.Size() * sizeof(DType)); + TBlob a_data(a_ptr, a.shape_, a.dev_mask(), a.dev_id()); + Tensor A = a_data.FlatToKD(s); + + MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { + mxnet_op::Kernel, xpu>::Launch( + s, w.Size(), w_ptr, w.dptr()); + }); + + // Set signs of eigenvectors in a deterministic way + mxnet_op::Kernel::Launch( + s, A.size(0) * A.size(1), A.size(1), A.dptr_, A.stride_); + + MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { + mxnet_op::Kernel, xpu>::Launch( + s, v.Size(), a_ptr, v.dptr(), + A.size(1), A.stride_, + A.size(1) * A.stride_); + }); + }); +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_NUMPY_LINALG_NP_EIG_INL_H_ diff --git a/src/operator/numpy/linalg/np_eig.cc b/src/operator/numpy/linalg/np_eig.cc new file mode 100644 index 000000000000..8a7fc0c3d606 --- /dev/null +++ b/src/operator/numpy/linalg/np_eig.cc @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_eig.cc + * \brief CPU implementation placeholder of Eig Operator + */ +#include "./np_eig-inl.h" + +namespace mxnet { +namespace op { + +// Inputs: A. +// Outputs: Eig, EigVec +bool EigOpShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector *in_attrs, + mxnet::ShapeVector *out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 2U); + const mxnet::TShape& a_shape = (*in_attrs)[0]; + const mxnet::TShape& eig_shape = (*out_attrs)[0]; + const mxnet::TShape& eigv_shape = (*out_attrs)[1]; + + if (shape_is_known(a_shape)) { + // Forward shape inference. + const int a_ndim = a_shape.ndim(); + CHECK_GE(a_ndim, 2) + << "Array must be at least two-dimensional"; + CHECK_EQ(a_shape[a_ndim - 2], a_shape[a_ndim - 1]) + << "Input A's last two dimension must be equal"; + + // Calculate eig shape. + std::vector eig_shape_vec(a_ndim - 1, -1); + for (int i = 0; i < a_ndim - 1; ++i) { + eig_shape_vec[i] = a_shape[i]; + } + mxnet::TShape eig_shape(eig_shape_vec.begin(), eig_shape_vec.end()); + SHAPE_ASSIGN_CHECK(*out_attrs, 0, eig_shape); + // Calculate eig vec shape: must have the same shape as A + SHAPE_ASSIGN_CHECK(*out_attrs, 1, a_shape); + } else { + // Backward shape inference. 
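+    // Three cases: if both W and V are known, A takes V's shape; if only W is known,
+    // V and A get W's shape with the eigen dimension appended; if only V is known,
+    // W is V's shape with the last dimension dropped.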
+ if (shape_is_known(eig_shape) && shape_is_known(eigv_shape)) { + const int eig_ndim = eig_shape.ndim(); + const int eigv_ndim = eigv_shape.ndim(); + CHECK_GE(eigv_ndim, 2) + << "Outputs V must be at least two-dimensional"; + CHECK_EQ(eigv_shape[eigv_ndim - 2], eigv_shape[eigv_ndim - 1]) + << "Outputs V's last two dimension must be equal"; + CHECK_EQ(eig_ndim + 1, eigv_ndim) + << "Outputs W, V must satisfy W.ndim == V.ndim - 1"; + for (int i = 0; i < eig_ndim; ++i) { + CHECK_EQ(eig_shape[i], eigv_shape[i]) + << "Outputs W, V's shape dismatch"; + } + SHAPE_ASSIGN_CHECK(*in_attrs, 0, eigv_shape); + } else if (shape_is_known(eig_shape)) { + const int eig_ndim = eig_shape.ndim(); + CHECK_GE(eig_ndim, 1) + << "Outputs W must be at least one-dimensional"; + std::vector eigv_shape_vec(eig_ndim + 1); + for (int i = 0; i < eig_ndim; ++i) { + eigv_shape_vec[i] = eig_shape[i]; + } + eigv_shape_vec[eig_ndim] = eig_shape[eig_ndim - 1]; + mxnet::TShape eigv_shape(eigv_shape_vec.begin(), eigv_shape_vec.end()); + SHAPE_ASSIGN_CHECK(*in_attrs, 0, eigv_shape); + SHAPE_ASSIGN_CHECK(*out_attrs, 1, eigv_shape); + } else { + const int eigv_ndim = eigv_shape.ndim(); + CHECK_GE(eigv_ndim, 2) + << "Outputs V must be at least two-dimensional"; + CHECK_EQ(eigv_shape[eigv_ndim - 2], eigv_shape[eigv_ndim - 1]) + << "Outputs V's last two dimension must be equal"; + std::vector eig_shape_vec(eigv_ndim - 1); + for (int i = 0; i < eigv_ndim - 1; ++i) { + eig_shape_vec[i] = eigv_shape[i]; + } + mxnet::TShape eig_shape(eig_shape_vec.begin(), eig_shape_vec.end()); + SHAPE_ASSIGN_CHECK(*in_attrs, 0, eigv_shape); + SHAPE_ASSIGN_CHECK(*out_attrs, 0, eig_shape); + } + } + return shape_is_known(*in_attrs) && shape_is_known(*out_attrs); +} + +inline bool EigOpType(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 2U); + int a_type = in_attrs->at(0); + // unsupport float16 + CHECK_NE(a_type, mshadow::kFloat16) + << "array type float16 is unsupported in linalg"; + if (mshadow::kFloat32 == a_type) { + TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); + TYPE_ASSIGN_CHECK(*out_attrs, 1, in_attrs->at(0)); + } else { + TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kFloat64); + TYPE_ASSIGN_CHECK(*out_attrs, 1, mshadow::kFloat64); + } + return out_attrs->at(0) != -1 && out_attrs->at(1) != -1; +} + +NNVM_REGISTER_OP(_npi_eig) +.set_num_inputs(1) +.set_num_outputs(2) +.set_attr("FListInputNames", [](const NodeAttrs& attrs){ + return std::vector{"A"}; +}) +.set_attr("FInferShape", EigOpShape) +.set_attr("FInferType", EigOpType) +.set_attr("THasDeterministicOutput", true) +.set_attr("FCompute", EigOpForward) +.set_attr("FGradient", MakeZeroGradNodes) +.add_argument("A", "NDArray-or-Symbol", "Tensor of square matrix"); + +DMLC_REGISTER_PARAMETER(EighParam); + +NNVM_REGISTER_OP(_npi_eigh) +.set_attr_parser(mxnet::op::ParamParser) +.set_num_inputs(1) +.set_num_outputs(2) +.set_attr("FListInputNames", [](const NodeAttrs& attrs){ + return std::vector{"A"}; +}) +.set_attr("FInferShape", EigOpShape) +.set_attr("FInferType", EigOpType) +.set_attr("FResourceRequest", [](const NodeAttrs& attrs){ + return std::vector{ResourceRequest::kTempSpace}; +}) +.set_attr("THasDeterministicOutput", true) +.set_attr("FCompute", EighOpForward) +.set_attr("FGradient", MakeZeroGradNodes) +.add_argument("A", "NDArray-or-Symbol", "Tensor of real matrices") +.add_arguments(EighParam::__FIELDS__()); + +} // namespace op +} // namespace mxnet diff --git 
a/src/operator/numpy/linalg/np_eig.cu b/src/operator/numpy/linalg/np_eig.cu
new file mode 100644
index 000000000000..c0184ad221d5
--- /dev/null
+++ b/src/operator/numpy/linalg/np_eig.cu
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file np_eig.cu
+ * \brief GPU implementation placeholder of Eig Operator
+ */
+
+#include
+#include "./np_eig-inl.h"
+
+namespace mxnet {
+namespace op {
+
+NNVM_REGISTER_OP(_npi_eig)
+.set_attr("FCompute", EigOpForward);
+
+#if MXNET_USE_CUSOLVER == 1
+
+NNVM_REGISTER_OP(_npi_eigh)
+.set_attr("FCompute", EighOpForward);
+
+#endif
+
+} // namespace op
+} // namespace mxnet
diff --git a/src/operator/numpy/linalg/np_eigvals-inl.h b/src/operator/numpy/linalg/np_eigvals-inl.h
new file mode 100644
index 000000000000..66b3b47eae84
--- /dev/null
+++ b/src/operator/numpy/linalg/np_eigvals-inl.h
@@ -0,0 +1,465 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors + * \file np_eigvals-inl.h + * \brief Placeholder for eigvals + */ +#ifndef MXNET_OPERATOR_NUMPY_LINALG_NP_EIGVALS_INL_H_ +#define MXNET_OPERATOR_NUMPY_LINALG_NP_EIGVALS_INL_H_ + +#include +#include +#include "../../operator_common.h" +#include "../../mshadow_op.h" +#include "../../tensor/la_op.h" +#include "../../tensor/la_op-inl.h" +#include "./np_solve-inl.h" + +namespace mxnet { +namespace op { + +using namespace mshadow; + +template +struct eigvals_assign_helper { + template + MSHADOW_XINLINE static void Map(int i, const DType *in_data, DType *out_data) { + KERNEL_ASSIGN(out_data[i], req, in_data[i]); + } +}; + +struct eigh_eigvalsh_helper { + template + MSHADOW_XINLINE static void Map(int i, const InDType *in_data, OutDType *out_data, + const int nrow, const int ld, const int step, bool USE_UP) { + int idx = i / step, row = (i % step) / ld, col = (i % step) % ld; + if ((USE_UP && row > col) || (!USE_UP && row < col)) { + out_data[idx * step + col + row * ld] = + static_cast(in_data[idx * step + row + col * ld]); + } else { + out_data[idx * step + col + row * ld] = + static_cast(in_data[idx * step + col + row * ld]); + } + } +}; + +template +void linalg_geev(char jobvl, + char jobvr, + const Tensor& a, + const Tensor& wr, + const Tensor& wi, + const Tensor& vl, + const Tensor& vr, + const Tensor& work_array); + +#define LINALG_CPU_EIG(fname, DType) \ +template<> inline \ +void linalg_geev(char jobvl, \ + char jobvr, \ + const Tensor& a, \ + const Tensor& wr, \ + const Tensor& wi, \ + const Tensor& vl, \ + const Tensor& vr, \ + const Tensor& work_array) { \ + const int n = a.size(1), lda = a.size(0); \ + const int lwork = work_array.shape_.Size(); \ + const int ldvl = vl.size(0), ldvr = vr.size(0); \ + int res(MXNET_LAPACK_##fname(MXNET_LAPACK_COL_MAJOR, jobvl, jobvr, \ + n, a.dptr_, lda, \ + wr.dptr_, wi.dptr_, \ + vl.dptr_, ldvl, \ + vr.dptr_, ldvr, \ + work_array.dptr_, lwork)); \ + CHECK_LE(res, 0) << #fname << "the QR algorithm failed to compute all the" \ + << "eigenvalues, and no eigenvectors have been computed; elements " \ + << res + 1 << ":N" << " of WR and WI contain eigenvalues which have converged"; \ + CHECK_GE(res, 0) << #fname << ": the " << -res \ + << "-th argument had an illegal value"; \ +} + +LINALG_CPU_EIG(sgeev, float) +LINALG_CPU_EIG(dgeev, double) + +#ifdef __CUDACC__ + +#define LINALG_GPU_EIG(fname, DType) \ +template<> inline \ +void linalg_geev(char jobvl, \ + char jobvr, \ + const Tensor& a, \ + const Tensor& wr, \ + const Tensor& wi, \ + const Tensor& vl, \ + const Tensor& vr, \ + const Tensor& work_array) { \ + LOG(FATAL) << "Lapack _geev routines in gpu is unsupported"; \ +} + +LINALG_GPU_EIG(sgeev, float) +LINALG_GPU_EIG(dgeev, double) + +#endif // __CUDACC__ + +struct eig_eigvals { + template + static void op(char jobvl, + char jobvr, + const Tensor& a, + const Tensor& wr, + const Tensor& wi, + const Tensor& vl, + const Tensor& vr, + const Tensor& work_array) { + const mxnet::TShape& a_shape = a.shape_; + const int a_ndim = a_shape.ndim(); + if (jobvl == 'N' && jobvr == 'N') { + CHECK_GE(work_array.shape_.Size(), 3 * a.shape_[a_ndim - 1]) + << "The dimension of the array WORK in LAPACKE_#GEEV should >= max(1,3*N)."; + } else { + CHECK_GE(work_array.shape_.Size(), 4 * a.shape_[a_ndim - 1]) + << "If JOBVL = 'V' or JOBVR = 'V', " + << "the dimension of the array WORK in LAPACKE_#GEEV should >= 4*N."; + } + for (int i = 0; i < a_shape[0]; ++i) { + if (jobvl == 'N' && jobvr == 'N') { + linalg_geev(jobvl, jobvr, a[i], 
wr[i], wi[i], vl[0], vr[0], work_array); + } else if (jobvl == 'N' && jobvr == 'V') { + linalg_geev(jobvl, jobvr, a[i], wr[i], wi[i], vl[0], vr[i], work_array); + } else if (jobvl == 'V' && jobvr == 'N') { + linalg_geev(jobvl, jobvr, a[i], wr[i], wi[i], vl[i], vr[0], work_array); + } else { + linalg_geev(jobvl, jobvr, a[i], wr[i], wi[i], vl[i], vr[i], work_array); + } + } + } +}; + +// Calculates workspace size of eigvals forward op. +// The dimension of the array WORK in LAPACKE_#GEEV should >= max(1,3*N), and +// if JOBVL = 'V' or JOBVR = 'V', LWORK >= 4*N. +// For good performance, LWORK must generally be larger. +template +size_t EigvalsForwardWorkspaceSize(const TBlob& a, + const TBlob& w, + const std::vector& req) { + if (kNullOp == req[0]) { return 0U; } + + // Zero-size input, no need to launch kernel + if (0U == a.Size()) { return 0U; } + + MSHADOW_SGL_DBL_TYPE_SWITCH(w.type_flag_, DType, { + size_t work_space_size = 0; + size_t n = a.size(a.ndim() - 1); + work_space_size += a.Size(); // For matrix. + work_space_size += 2 * w.Size(); // For eigenvalues' real and image component. + work_space_size += 2 * n * n; // For left and right eigenvectors temp memory + work_space_size += 3 * n; // For workspace size in LAPACKE_#GEEV. + work_space_size *= sizeof(DType); + return work_space_size; + }); + LOG(FATAL) << "InternalError: cannot reach here"; + return 0U; +} + +template +void EigvalsOpForwardImpl(const TBlob& a, + const TBlob& w, + const std::vector& req, + std::vector *workspace, + mshadow::Stream *s) { + if (kNullOp == req[0]) { return; } + const mxnet::TShape& a_shape = a.shape_; + const mxnet::TShape& w_shape = w.shape_; + const int a_ndim = a_shape.ndim(); + + // Zero-size output, no need to launch kernel + if (0U == a.Size()) { return; } + + MSHADOW_SGL_DBL_TYPE_SWITCH(w.type_flag_, DType, { + const int N = a.size(a_ndim - 1); + DType *a_ptr = + reinterpret_cast(workspace->data()); + DType *wr_ptr = + reinterpret_cast(workspace->data() + a.Size() * sizeof(DType)); + DType *wi_ptr = + reinterpret_cast(workspace->data() + (w.Size() + a.Size()) * sizeof(DType)); + DType *vl_ptr = + reinterpret_cast(workspace->data() + (2 * w.Size() + a.Size()) * sizeof(DType)); + DType *vr_ptr = + reinterpret_cast( + workspace->data() + (N * N + 2 * w.Size() + a.Size()) * sizeof(DType)); + DType *work_ptr = + reinterpret_cast( + workspace->data() + (2 * (N * N + w.Size()) + a.Size()) * sizeof(DType)); + MSHADOW_TYPE_SWITCH(a.type_flag_, AType, { + // Cast type and transpose. 
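+      // The cast-and-transpose kernel writes A in column-major order, matching the
+      // Fortran layout required by the geev wrapper.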
+ mxnet_op::Kernel::Launch( + s, a.Size(), a.dptr(), a_ptr, N, N, N * N); + }); + s->Wait(); + char jobvl = 'N', jobvr = 'N'; + mxnet::TBlob a_trans_data(a_ptr, a_shape, a.dev_mask(), a.dev_id()); + mxnet::TBlob wr_data(wr_ptr, w_shape, w.dev_mask(), w.dev_id()); + mxnet::TBlob wi_data(wi_ptr, w_shape, w.dev_mask(), w.dev_id()); + mxnet::TBlob vl_data(vl_ptr, Shape3(1, N, N), w.dev_mask(), w.dev_id()); + mxnet::TBlob vr_data(vr_ptr, Shape3(1, N, N), w.dev_mask(), w.dev_id()); + mxnet::TBlob work_data(work_ptr, Shape1(3 * N), a.dev_mask(), a.dev_id()); + eig_eigvals::op(jobvl, jobvr, + a_trans_data.FlatToKD(s), + wr_data.FlatToKD(s), + wi_data.FlatToKD(s), + vl_data.get(s), + vr_data.get(s), + work_data.get(s)); + for (size_t i = 0; i < wi_data.Size(); ++i) { + CHECK_LE(fabs(wi_ptr[i]), 1e-15) + << "Complex eigvals is unsupported in linalg temporary."; + } + MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { + mxnet_op::Kernel, xpu>::Launch( + s, w.Size(), wr_ptr, w.dptr()); + }); + }); +} + +template +void GpuCallbackCpuImpl(const TBlob& a, + const TBlob& w, + AType* a_cp_ptr, + WType* w_cp_ptr, + std::vector* workspace, + const OpContext& ctx, + const std::vector& req) { +#if MXNET_USE_CUDA + mshadow::Stream *s = ctx.get_stream(); + cudaStream_t stream = Stream::GetStream(s); + CUDA_CALL(cudaMemcpyAsync(a_cp_ptr, a.dptr(), sizeof(AType) * a.Size(), + cudaMemcpyDeviceToHost, stream)); + CUDA_CALL(cudaMemcpyAsync(w_cp_ptr, w.dptr(), sizeof(WType) * w.Size(), + cudaMemcpyDeviceToHost, stream)); + CUDA_CALL(cudaStreamSynchronize(stream)); + mxnet::TBlob a_data(a_cp_ptr, a.shape_, cpu::kDevMask); + mxnet::TBlob w_data(w_cp_ptr, w.shape_, cpu::kDevMask); + // Op forward implement on cpu. + EigvalsOpForwardImpl(a_data, w_data, req, workspace, ctx.get_stream()); + // Copy back to gpu. + CUDA_CALL(cudaMemcpyAsync(w.dptr(), w_cp_ptr, sizeof(WType) * w.Size(), + cudaMemcpyHostToDevice, stream)); + CUDA_CALL(cudaStreamSynchronize(stream)); +#else + LOG(FATAL) << "Please build with USE_CUDA=1 to enable GPU"; +#endif // MXNET_USE_CUDA +} + +template +void EigvalsOpForward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 1U); + CHECK_EQ(req.size(), 1U); + const TBlob& a = inputs[0]; + const TBlob& w = outputs[0]; + + // Calculate workspace size. + size_t workspace_size = EigvalsForwardWorkspaceSize(a, w, req); + std::vector workspace(workspace_size, 0); + + MSHADOW_SGL_DBL_TYPE_SWITCH(w.type_flag_, WType, { + MSHADOW_TYPE_SWITCH(a.type_flag_, AType, { + if (xpu::kDevCPU) { + // Op forward implement. 
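+        // CPU path: compute directly with LAPACK; the GPU branch below stages the data
+        // through host buffers instead.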
+ EigvalsOpForwardImpl(a, w, req, &workspace, ctx.get_stream()); + } else { + std::vector a_vec(a.Size(), 0); + std::vector w_vec(w.Size(), 0); + AType* a_cp_ptr = a_vec.data(); + WType* w_cp_ptr = w_vec.data(); + GpuCallbackCpuImpl(a, w, a_cp_ptr, w_cp_ptr, &workspace, ctx, req); + } + }); + }); +} + +struct EigvalshParam : public dmlc::Parameter { + char UPLO; + DMLC_DECLARE_PARAMETER(EigvalshParam) { + DMLC_DECLARE_FIELD(UPLO) + .set_default('L') + .describe("Specifies whether the calculation is done with the lower or upper triangular part."); + } +}; + +template +size_t EighEigvalshForwardWorkspaceSize(const TBlob& a, + const TBlob& w, + const std::vector& req, + const OpContext& ctx) { + if (kNullOp == req[0]) { return 0U; } + + // Zero-size input, no need to launch kernel + if (0U == a.Size()) { return 0U; } + + Stream *s = ctx.get_stream(); + MSHADOW_TYPE_SWITCH(a.type_flag_, AType, { + MSHADOW_SGL_DBL_TYPE_SWITCH(w.type_flag_, WType, { + Tensor a_temp_tensor = a.FlatToKD(s)[0]; + Tensor w_temp_tensor = w.FlatToKD(s)[0]; + size_t work_space_size = 0; + work_space_size += a.Size(); // For matrix. + work_space_size += w.Size(); // For eigenvalues. + if (xpu::kDevCPU) { + std::vector A_data(a_temp_tensor.shape_.Size(), 0); + std::vector W_data(w_temp_tensor.shape_.Size(), 0); + TBlob a_data(A_data.data(), a_temp_tensor.shape_, a.dev_mask(), a.dev_id()); + TBlob w_data(W_data.data(), w_temp_tensor.shape_, w.dev_mask(), w.dev_id()); + work_space_size += // For workspace size in LAPACKE_#SYEVD. + linalg_syevd_workspace_query(a_data.get(s), + w_data.get(s), s); + } else { + Storage::Handle A_data = + Storage::Get()->Alloc(sizeof(WType) * a_temp_tensor.shape_.Size(), Context::GPU()); + Storage::Handle W_data = + Storage::Get()->Alloc(sizeof(WType) * w_temp_tensor.shape_.Size(), Context::GPU()); + TBlob a_data(static_cast(A_data.dptr), + a_temp_tensor.shape_, a.dev_mask(), a.dev_id()); + TBlob w_data(static_cast(W_data.dptr), + w_temp_tensor.shape_, w.dev_mask(), w.dev_id()); + work_space_size += // For workspace size in LAPACKE_#SYEVD. + linalg_syevd_workspace_query(a_data.get(s), + w_data.get(s), s); + Storage::Get()->Free(A_data); + Storage::Get()->Free(W_data); + } + return work_space_size * sizeof(WType); + }); + }); + LOG(FATAL) << "InternalError: cannot reach here"; + return 0U; +} + +struct print_helper { + template + MSHADOW_XINLINE static void Map(int i, const DType *in_data) { + printf("%lf, ", in_data[i]); + } +}; + +template +void EighEigvalshOpForwardImpl(const TBlob& a, + const TBlob& w, + const char& UPLO, + const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& req, + const Tensor& workspace) { + Stream *s = ctx.get_stream(); + MSHADOW_SGL_DBL_TYPE_SWITCH(w.type_flag_, DType, { + const size_t workspace_size = + (workspace.shape_.Size() + sizeof(DType) - 1) / sizeof(DType); + DType *a_ptr = + reinterpret_cast(workspace.dptr_); + DType *w_ptr = + reinterpret_cast(workspace.dptr_ + a.Size() * sizeof(DType)); + DType *work_ptr = + reinterpret_cast(workspace.dptr_ + (a.Size() + w.Size()) * sizeof(DType)); + TBlob a_data(a_ptr, a.shape_, a.dev_mask(), a.dev_id()); + TBlob w_data(w_ptr, w.shape_, w.dev_mask(), w.dev_id()); + TBlob work_data(work_ptr, + Shape1(workspace_size - a.Size() - w.Size()), + w.dev_mask(), w.dev_id()); + Tensor A = a_data.FlatToKD(s); + Tensor W = w_data.FlatToKD(s); + Tensor Work = work_data.get(s); + // Copy used upper triangle part of 'a'. 
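+    // Workspace layout (in DType elements): [ a copy : a.Size() | w : w.Size() | syevd work : rest ],
+    // as sized by EighEigvalshForwardWorkspaceSize. The kernel below copies the requested UPLO
+    // triangle of 'a' into the workspace copy; linalg_syevd then overwrites that copy and writes
+    // the eigenvalues of each matrix in the batch into the corresponding row of W.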
+ MSHADOW_TYPE_SWITCH(a.type_flag_, AType, { + mxnet_op::Kernel::Launch( + s, a.Size(), + a.dptr(), a_ptr, + A.size(1), A.stride_, + A.size(1) * A.stride_, + UPLO == 'U'); + }); + // Loop over items in batch + for (index_t i = 0; i < A.size(0); ++i) { + // Input 'a' must be symmetric, only lower triangle is useds + // Needs workspace (both DType and int), size of which is determined by a workspace query + linalg_syevd(A[i], W[i], Work, s); + } + }); +} + +template +void EigvalshOpForward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 1U); + CHECK_EQ(req.size(), 1U); + const TBlob& a = inputs[0]; + const TBlob& w = outputs[0]; + char UPLO = nnvm::get(attrs.parsed).UPLO; + Stream *s = ctx.get_stream(); + + if (kNullOp == req[0]) { return; } + // Zero-size output, no need to launch kernel + if (0U == a.Size()) { return; } + + // Calculate workspace size. + size_t workspace_size = EighEigvalshForwardWorkspaceSize(a, w, req, ctx); + Tensor workspace = + ctx.requested[0].get_space_typed(Shape1(workspace_size), s); + + EighEigvalshOpForwardImpl(a, w, UPLO, attrs, ctx, req, workspace); + + MSHADOW_SGL_DBL_TYPE_SWITCH(w.type_flag_, DType, { + const mxnet::TShape w_shape = w.shape_; + const int w_ndim = w_shape.ndim(); + DType *w_ptr = reinterpret_cast(workspace.dptr_ + a.Size() * sizeof(DType)); + if (w_shape[w_ndim - 1] == 3) { + std::cout << "result in workspace:" << std::endl; + mxnet_op::Kernel::Launch(s, w_shape[w_ndim - 1], w_ptr); + std::cout << std::endl; + } + MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { + mxnet_op::Kernel, xpu>::Launch( + s, w.Size(), + reinterpret_cast(workspace.dptr_ + a.Size() * sizeof(DType)), + w.dptr()); + }); + if (w_shape[w_ndim - 1] == 3) { + std::cout << "in result:" << std::endl; + mxnet_op::Kernel::Launch(s, w_shape[w_ndim - 1], w.dptr()); + std::cout << std::endl; + } + }); +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_NUMPY_LINALG_NP_EIGVALS_INL_H_ diff --git a/src/operator/numpy/linalg/np_eigvals.cc b/src/operator/numpy/linalg/np_eigvals.cc new file mode 100644 index 000000000000..3f931ca8ee78 --- /dev/null +++ b/src/operator/numpy/linalg/np_eigvals.cc @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_eigvals.cc + * \brief CPU implementation placeholder of Eigvals Operator + */ +#include "./np_eigvals-inl.h" + +namespace mxnet { +namespace op { + +// Inputs: A. +// Outputs: Eig. 
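+// Shape relation: A (..., M, M)  ->  Eig (..., M).
+// For example, an input of shape (2, 3, 4, 4) yields eigenvalues of shape (2, 3, 4).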
+bool EigvalsOpShape(const nnvm::NodeAttrs& attrs,
+                    mxnet::ShapeVector *in_attrs,
+                    mxnet::ShapeVector *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  const mxnet::TShape& a_shape = (*in_attrs)[0];
+  const mxnet::TShape& eig_shape = (*out_attrs)[0];
+
+  if (shape_is_known(a_shape)) {
+    // Forward shape inference.
+    const int a_ndim = a_shape.ndim();
+    CHECK_GE(a_ndim, 2)
+      << "Array must be at least two-dimensional";
+    CHECK_EQ(a_shape[a_ndim - 2], a_shape[a_ndim - 1])
+      << "Input A's last two dimensions must be equal";
+
+    // Calculate eig shape.
+    std::vector<int> eig_shape_vec(a_ndim - 1, -1);
+    for (int i = 0; i < a_ndim - 1; ++i) {
+      eig_shape_vec[i] = a_shape[i];
+    }
+    mxnet::TShape eig_shape(eig_shape_vec.begin(), eig_shape_vec.end());
+    SHAPE_ASSIGN_CHECK(*out_attrs, 0, eig_shape);
+  } else if (shape_is_known(eig_shape)) {
+    // Backward shape inference.
+    const int eig_ndim = eig_shape.ndim();
+    CHECK_GE(eig_ndim, 1)
+      << "Output W must be at least one-dimensional";
+    std::vector<int> a_shape_vec(eig_ndim + 1);
+    for (int i = 0; i < eig_ndim; ++i) {
+      a_shape_vec[i] = eig_shape[i];
+    }
+    a_shape_vec[eig_ndim] = eig_shape[eig_ndim - 1];
+    mxnet::TShape a_shape(a_shape_vec.begin(), a_shape_vec.end());
+    SHAPE_ASSIGN_CHECK(*in_attrs, 0, a_shape);
+  }
+  return shape_is_known(*in_attrs) && shape_is_known(*out_attrs);
+}
+
+inline bool EigvalsOpType(const nnvm::NodeAttrs& attrs,
+                          std::vector<int>* in_attrs,
+                          std::vector<int>* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  int a_type = in_attrs->at(0);
+  // float16 is not supported.
+  CHECK_NE(a_type, mshadow::kFloat16)
+    << "array type float16 is unsupported in linalg";
+  if (mshadow::kFloat32 == a_type) {
+    TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
+  } else {
+    TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kFloat64);
+  }
+  return out_attrs->at(0) != -1;
+}
+
+NNVM_REGISTER_OP(_npi_eigvals)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs){
+  return std::vector<std::string>{"A"};
+})
+.set_attr<mxnet::FInferShape>("FInferShape", EigvalsOpShape)
+.set_attr<nnvm::FInferType>("FInferType", EigvalsOpType)
+.set_attr<THasDeterministicOutput>("THasDeterministicOutput", true)
+.set_attr<FCompute>("FCompute", EigvalsOpForward<cpu>)
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
+.add_argument("A", "NDArray-or-Symbol", "Tensor of square matrices");
+
+DMLC_REGISTER_PARAMETER(EigvalshParam);
+
+NNVM_REGISTER_OP(_npi_eigvalsh)
+.set_attr_parser(mxnet::op::ParamParser<EigvalshParam>)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs){
+  return std::vector<std::string>{"A"};
+})
+.set_attr<mxnet::FInferShape>("FInferShape", EigvalsOpShape)
+.set_attr<nnvm::FInferType>("FInferType", EigvalsOpType)
+.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs){
+  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+})
+.set_attr<THasDeterministicOutput>("THasDeterministicOutput", true)
+.set_attr<FCompute>("FCompute", EigvalshOpForward<cpu>)
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
+.add_argument("A", "NDArray-or-Symbol", "Tensor of square matrices")
+.add_arguments(EigvalshParam::__FIELDS__());
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/numpy/linalg/np_eigvals.cu b/src/operator/numpy/linalg/np_eigvals.cu
new file mode 100644
index 000000000000..974dedc6172e
--- /dev/null
+++ b/src/operator/numpy/linalg/np_eigvals.cu
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file np_eigvals.cu + * \brief GPU implementation placeholder of Eigvals Operator + */ + +#include +#include "./np_eigvals-inl.h" + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(_npi_eigvals) +.set_attr("FCompute", EigvalsOpForward); + +#if MXNET_USE_CUSOLVER == 1 + +NNVM_REGISTER_OP(_npi_eigvalsh) +.set_attr("FCompute", EigvalshOpForward); + +#endif + +} // namespace op +} // namespace mxnet diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py index 1ee116c20959..347e22a93359 100644 --- a/tests/python/unittest/test_numpy_interoperability.py +++ b/tests/python/unittest/test_numpy_interoperability.py @@ -525,6 +525,30 @@ def _add_workload_linalg_pinv(): OpArgMngr.add_workload('linalg.pinv', np.array(a_np, dtype=dtype), np.array(rcond_np, dtype=dtype), hermitian) +def _add_workload_linalg_eigvals(): + OpArgMngr.add_workload('linalg.eigvals', np.array(_np.diag((0, 0)), dtype=np.float64)) + OpArgMngr.add_workload('linalg.eigvals', np.array(_np.diag((1, 1)), dtype=np.float64)) + OpArgMngr.add_workload('linalg.eigvals', np.array(_np.diag((2, 2)), dtype=np.float64)) + + +def _add_workload_linalg_eig(): + OpArgMngr.add_workload('linalg.eig', np.array(_np.diag((0, 0)), dtype=np.float64)) + OpArgMngr.add_workload('linalg.eig', np.array(_np.diag((1, 1)), dtype=np.float64)) + OpArgMngr.add_workload('linalg.eig', np.array(_np.diag((2, 2)), dtype=np.float64)) + + +def _add_workload_linalg_eigvalsh(): + OpArgMngr.add_workload('linalg.eigvalsh', np.array(_np.diag((0, 0)), dtype=np.float64)) + OpArgMngr.add_workload('linalg.eigvalsh', np.array(_np.diag((1, 1)), dtype=np.float64)) + OpArgMngr.add_workload('linalg.eigvalsh', np.array(_np.diag((2, 2)), dtype=np.float64)) + + +def _add_workload_linalg_eigh(): + OpArgMngr.add_workload('linalg.eigh', np.array(_np.diag((0, 0)), dtype=np.float64)) + OpArgMngr.add_workload('linalg.eigh', np.array(_np.diag((1, 1)), dtype=np.float64)) + OpArgMngr.add_workload('linalg.eigh', np.array(_np.diag((2, 2)), dtype=np.float64)) + + def _add_workload_linalg_slogdet(): OpArgMngr.add_workload('linalg.slogdet', np.array(_np.ones((2, 2)), dtype=np.float32)) OpArgMngr.add_workload('linalg.slogdet', np.array(_np.ones((0, 1, 1)), dtype=np.float64)) @@ -1745,6 +1769,10 @@ def _prepare_workloads(): _add_workload_linalg_tensorinv() _add_workload_linalg_tensorsolve() _add_workload_linalg_pinv() + _add_workload_linalg_eigvals() + _add_workload_linalg_eig() + _add_workload_linalg_eigvalsh() + _add_workload_linalg_eigh() _add_workload_linalg_slogdet() _add_workload_linalg_cond() _add_workload_trace() diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 8d54a53b8bc4..71ed67bb2187 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -32,6 +32,8 @@ from mxnet.base import MXNetError from mxnet.test_utils import same, 
assert_almost_equal, rand_shape_nd, rand_ndarray from mxnet.test_utils import check_numeric_gradient, use_np, collapse_sum_like +from mxnet.test_utils import new_matrix_with_real_eigvals_nd +from mxnet.test_utils import new_sym_matrix_with_real_eigvals_nd from common import assertRaises, with_seed import random from mxnet.test_utils import verify_generator, gen_buckets_probs_with_ppf @@ -4559,6 +4561,332 @@ def check_pinv(x, a_np, rcond_np, hermitian, use_rcond): check_pinv(mx_out, a.asnumpy(), rcond.asnumpy(), hermitian, use_rcond) +@with_seed() +@use_np +def test_np_linalg_eigvals(): + class TestEigvals(HybridBlock): + def __init__(self): + super(TestEigvals, self).__init__() + + def hybrid_forward(self, F, a): + return F.np.linalg.eigvals(a) + + def check_eigvals(x, a_np): + try: + x_expected = _np.linalg.eigvals(a_np) + except Exception as e: + print("a:", a_np) + print("a shape:", a_np.shape) + print(e) + else: + assert x.shape == x_expected.shape + if 0 not in x.shape: + n = int(_np.prod(x.shape[:-1])) if len(shape) > 1 else 1 + x = x.reshape(n, -1) + x_expected = x_expected.reshape(n, -1) + for i in range(n): + x1 = _np.sort(x[i].asnumpy()) + x2 = _np.sort(x_expected[i]) + assert_almost_equal(x1, x2, rtol=rtol, atol=atol) + + shapes = [ + (0, 0), + (1, 1), + (3, 3), + (5, 5), + (1, 0, 0), + (0, 4, 4), + (1, 4, 4), + (2, 4, 4), + (5, 5, 5), + (1, 1, 4, 4), + (2, 3, 4, 4) + ] + dtypes = ['float32', 'float64', 'int32', 'int64'] + UPLOs = ['L', 'U'] + for hybridize in [True, False]: + for shape, dtype in itertools.product(shapes, dtypes): + rtol = 1e-2 if dtype == 'float32' else 1e-3 + atol = 1e-4 if dtype == 'float32' else 1e-5 + test_eigvals = TestEigvals() + if hybridize: + test_eigvals.hybridize() + if 0 in shape: + a_np = _np.ones(shape) + else: + if dtype == 'int32' or dtype == 'int64': + n = int(_np.prod(shape[:-2])) if len(shape) > 2 else 1 + a_np = _np.array([_np.diag(_np.random.randint(1, 10, size=shape[-1])) for i in range(n)]).reshape(shape) + else: + a_np = new_matrix_with_real_eigvals_nd(shape) + a = np.array(a_np, dtype=dtype) + # check eigvals validity + mx_out = test_eigvals(a) + check_eigvals(mx_out, a.asnumpy()) + # check imperative once again + mx_out = test_eigvals(a) + check_eigvals(mx_out, a.asnumpy()) + + +@with_seed() +@use_np +def test_np_linalg_eigvalsh(): + class TestEigvalsh(HybridBlock): + def __init__(self, UPLO): + super(TestEigvalsh, self).__init__() + self._UPLO = UPLO + + def hybrid_forward(self, F, a): + return F.np.linalg.eigvalsh(a, UPLO=self._UPLO) + + def check_eigvalsh(w, a_np, UPLO): + try: + w_expected = _np.linalg.eigvalsh(a_np, UPLO) + except Exception as e: + print("a:", a_np) + print("a shape:", a_np.shape) + print(e) + else: + assert w.shape == w_expected.shape + assert_almost_equal(w, w_expected, rtol=rtol, atol=atol) + + def new_matrix_from_sym_matrix_nd(sym_a, UPLO): + shape = sym_a.shape + if 0 in shape: + return sym_a + n = int(_np.prod(shape[:-2])) if len(shape) > 2 else 1 + a = sym_a.reshape(n, shape[-2], shape[-1]) + for idx in range(n): + for i in range(shape[-2]): + for j in range(shape[-1]): + if ((UPLO == 'U' and i > j) or (UPLO == 'L' and i < j)): + a[idx][i][j] = _np.random.uniform(-10., 10.) 
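+                        # The triangle not referenced by UPLO is overwritten with random values,
+                        # so the test also verifies that eigvalsh reads only the UPLO triangle.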
+ return a.reshape(shape) + + shapes = [ + (0, 0), + (1, 1), + (3, 3), + (5, 5), + (1, 0, 0), + (0, 4, 4), + (1, 4, 4), + (2, 4, 4), + (5, 5, 5), + (1, 1, 4, 4), + (2, 3, 4, 4) + ] + dtypes = ['float32', 'float64', 'int32', 'int64'] + UPLOs = ['L', 'U'] + for hybridize in [True, False]: + for shape, dtype, UPLO in itertools.product(shapes, dtypes, UPLOs): + rtol = 1e-2 if dtype == 'float32' else 1e-3 + atol = 1e-4 if dtype == 'float32' else 1e-5 + test_eigvalsh = TestEigvalsh(UPLO) + if hybridize: + test_eigvalsh.hybridize() + if 0 in shape: + a_np = _np.ones(shape) + else: + if dtype == 'int32' or dtype == 'int64': + n = int(_np.prod(shape[:-2])) if len(shape) > 2 else 1 + a_np = _np.array([_np.diag(_np.random.randint(1, 10, size=shape[-1])) for i in range(n)], dtype=dtype).reshape(shape) + else: + a_np = new_sym_matrix_with_real_eigvals_nd(shape) + a_np = new_matrix_from_sym_matrix_nd(a_np, UPLO) + a = np.array(a_np, dtype=dtype) + # check eigvalsh validity + mx_out = test_eigvalsh(a) + check_eigvalsh(mx_out, a.asnumpy(), UPLO) + # check imperative once again + mx_out = test_eigvalsh(a) + check_eigvalsh(mx_out, a.asnumpy(), UPLO) + + +@with_seed() +@use_np +def test_np_linalg_eig(): + class TestEig(HybridBlock): + def __init__(self): + super(TestEig, self).__init__() + + def hybrid_forward(self, F, a): + return F.np.linalg.eig(a) + + def check_eig(w, v, a_np): + try: + w_expected, v_expected = _np.linalg.eig(a_np) + except Exception as e: + print("a:", a_np) + print("a shape:", a_np.shape) + print(e) + else: + assert w.shape == w_expected.shape + assert v.shape == v_expected.shape + if 0 not in a_np.shape: + n = int(_np.prod(w.shape[:-1])) if len(shape) > 1 else 1 + N = a_np.shape[-1] + w = w.reshape(n, N) + w_expected = w_expected.reshape(n, N) + v = v.reshape(n, N, N) + v_expected = v_expected.reshape(n, N, N) + a_np = a_np.reshape(n, N, N) + for i in range(n): + # check eigenvector + ai = a_np[i] + vi = (v[i].asnumpy()).T + wi = w[i].asnumpy() + for j in range(N): + assert_almost_equal(wi[j] * vi[j], _np.matmul(ai, vi[j]), rtol=rtol, atol=atol) + + # check eigenvalues + w1 = _np.sort(w[i].asnumpy()) + w2 = _np.sort(w_expected[i]) + assert_almost_equal(w1, w2, rtol=rtol, atol=atol) + + shapes = [ + (0, 0), + (1, 1), + (3, 3), + (5, 5), + (1, 0, 0), + (0, 4, 4), + (1, 4, 4), + (2, 4, 4), + (5, 5, 5), + (1, 1, 4, 4), + (2, 3, 4, 4) + ] + dtypes = ['float32', 'float64', 'int32', 'int64'] + for hybridize in [True, False]: + for shape, dtype in itertools.product(shapes, dtypes): + rtol = 1e-2 if dtype == 'float32' else 1e-3 + atol = 1e-4 if dtype == 'float32' else 1e-5 + test_eig = TestEig() + if hybridize: + test_eig.hybridize() + if 0 in shape: + a_np = _np.ones(shape) + else: + if dtype == 'int32' or dtype == 'int64': + n = int(_np.prod(shape[:-2])) if len(shape) > 2 else 1 + a_np = _np.array([_np.diag(_np.random.randint(1, 10, size=shape[-1])) for i in range(n)]).reshape(shape) + else: + a_np = new_matrix_with_real_eigvals_nd(shape) + a = np.array(a_np, dtype=dtype) + # check eig validity + mx_w, mx_v = test_eig(a) + check_eig(mx_w, mx_v, a.asnumpy()) + # check imperative once again + mx_w, mx_v = test_eig(a) + check_eig(mx_w, mx_v, a.asnumpy()) + + +@with_seed() +@use_np +def test_np_linalg_eigh(): + class TestEigh(HybridBlock): + def __init__(self, UPLO): + super(TestEigh, self).__init__() + self._UPLO = UPLO + + def hybrid_forward(self, F, a): + return F.np.linalg.eigh(a, UPLO=self._UPLO) + + def check_eigh(w, v, a_np, UPLO): + try: + w_expected, v_expected = _np.linalg.eigh(a_np, 
UPLO) + except Exception as e: + print("a:", a_np) + print("a shape:", a_np.shape) + print(e) + else: + assert w.shape == w_expected.shape + assert v.shape == v_expected.shape + # check eigenvalues. + assert_almost_equal(w, w_expected, rtol=rtol, atol=atol) + # check eigenvectors. + w_shape, v_shape, a_sym_np = get_sym_matrix_nd(a_np, UPLO) + w_np = w.asnumpy() + v_np = v.asnumpy() + if 0 not in a_np.shape: + w_np = w_np.reshape(w_shape) + v_np = v_np.reshape(v_shape) + a_sym_np = a_sym_np.reshape(v_shape) + for i in range(w_shape[0]): + for j in range(w_shape[1]): + assert_almost_equal(_np.dot(a_sym_np[i], v_np[i][:, j]), w_np[i][j] * v_np[i][:, j], rtol=rtol, atol=atol) + + def get_sym_matrix_nd(a_np, UPLO): + a_res_np = a_np + shape = a_np.shape + if 0 not in a_np.shape: + n = int(_np.prod(shape[:-2])) if len(shape) > 2 else 1 + nrow, ncol = shape[-2], shape[-1] + a_np = a_np.reshape(n, nrow, ncol) + a_res_np = a_np + for idx in range(n): + for i in range(nrow): + for j in range(ncol): + if ((UPLO == 'L' and i < j) or (UPLO == 'U' and i > j)): + a_res_np[idx][i][j] = a_np[idx][j][i] + return (n, nrow), (n, nrow, ncol), a_res_np.reshape(shape) + else : + return (0, 0), (0, 0, 0), a_res_np.reshape(shape) + + def new_matrix_from_sym_matrix_nd(sym_a, UPLO): + shape = sym_a.shape + if 0 in shape: + return sym_a + n = int(_np.prod(shape[:-2])) if len(shape) > 2 else 1 + a = sym_a.reshape(n, shape[-2], shape[-1]) + for idx in range(n): + for i in range(shape[-2]): + for j in range(shape[-1]): + if ((UPLO == 'U' and i > j) or (UPLO == 'L' and i < j)): + a[idx][i][j] = _np.random.uniform(-10., 10.) + return a.reshape(shape) + + shapes = [ + (0, 0), + (1, 1), + (3, 3), + (5, 5), + (1, 0, 0), + (0, 4, 4), + (1, 4, 4), + (2, 4, 4), + (5, 5, 5), + (1, 1, 4, 4), + (2, 3, 4, 4) + ] + dtypes = ['float32', 'float64', 'int32', 'int64'] + UPLOs = ['L', 'U'] + for hybridize in [True, False]: + for shape, dtype, UPLO in itertools.product(shapes, dtypes, UPLOs): + rtol = 1e-2 if dtype == 'float32' else 1e-3 + atol = 1e-4 if dtype == 'float32' else 1e-5 + test_eigh = TestEigh(UPLO) + if hybridize: + test_eigh.hybridize() + if 0 in shape: + a_np = _np.ones(shape) + else: + if dtype == 'int32' or dtype == 'int64': + n = int(_np.prod(shape[:-2])) if len(shape) > 2 else 1 + a_np = _np.array([_np.diag(_np.random.randint(1, 10, size=shape[-1])) for i in range(n)], dtype=dtype).reshape(shape) + else: + a_np = new_sym_matrix_with_real_eigvals_nd(shape) + a_np = new_matrix_from_sym_matrix_nd(a_np, UPLO) + a = np.array(a_np, dtype=dtype) + # check eigh validity + w, v = test_eigh(a) + check_eigh(w, v, a.asnumpy(), UPLO) + # check imperative once again + w, v = test_eigh(a) + check_eigh(w, v, a.asnumpy(), UPLO) + + @with_seed() @use_np def test_np_linalg_det():