Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fix bug in test
Browse files Browse the repository at this point in the history
fix bug in MXNET_LAPACK_FSIG_GESV
fix bug
fix format
fix undefined #gesv
fix format
  • Loading branch information
Ubuntu committed Nov 27, 2019
1 parent a98cefc commit 983bb2e
Show file tree
Hide file tree
Showing 11 changed files with 991 additions and 5 deletions.
57 changes: 56 additions & 1 deletion python/mxnet/ndarray/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from . import _op as _mx_nd_np
from . import _internal as _npi

__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet']
__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet', 'solve']


def norm(x, ord=None, axis=None, keepdims=False):
Expand Down Expand Up @@ -352,3 +352,58 @@ def slogdet(a):
(1., -1151.2925464970228)
"""
return _npi.slogdet(a)


def solve(a, b):
r"""
Solve a linear matrix equation, or system of linear scalar equations.
Computes the "exact" solution, x, of the well-determined, i.e., full rank, linear matrix equation ax = b.
Parameters :
a : (..., M, M) ndarray
Coefficient matrix.
b : {(..., M,), (..., M, K)}, ndarray
Ordinate or "dependent variable" values.
Returns :
x : {(..., M,), (..., M, K)} ndarray
Solution to the system a x = b. Returned shape is identical to b.
Raises :
LinAlgError
If a is singular or not square.
Notes
Broadcasting rules apply, see the numpy.linalg documentation for details.
The solutions are computed using LAPACK routine _gesv
a must be square and of full-rank, i.e., all rows (or, equivalently, columns)
must be linearly independent; if either is not true, use lstsq for the least-squares
best "solution" of the system/equation.
References
[R41] G. Strang, Linear Algebra and Its Applications, 2nd Ed., Orlando, FL, Academic Press, Inc., 1980, pg. 22.
Examples
Solve the system of equations 3 * x0 + x1 = 9 and x0 + 2 * x1 = 8:
>>> a = np.array([[3,1], [1,2]])
>>> b = np.array([9,8])
>>> x = np.linalg.solve(a, b)
>>> x
array([ 2., 3.])
Check that the solution is correct:
>>> np.allclose(np.dot(a, x), b)
True
"""
return _npi.solve(a, b)
56 changes: 55 additions & 1 deletion python/mxnet/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from __future__ import absolute_import
from ..ndarray import numpy as _mx_nd_np

__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet']
__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet', 'solve']


def norm(x, ord=None, axis=None, keepdims=False):
Expand Down Expand Up @@ -370,3 +370,57 @@ def slogdet(a):
(1., -1151.2925464970228)
"""
return _mx_nd_np.linalg.slogdet(a)

def solve(a, b):
r"""
Solve a linear matrix equation, or system of linear scalar equations.
Computes the "exact" solution, x, of the well-determined, i.e., full rank, linear matrix equation ax = b.
Parameters :
a : (..., M, M) ndarray
Coefficient matrix.
b : {(..., M,), (..., M, K)}, ndarray
Ordinate or "dependent variable" values.
Returns :
x : {(..., M,), (..., M, K)} ndarray
Solution to the system a x = b. Returned shape is identical to b.
Raises :
LinAlgError
If a is singular or not square.
Notes
Broadcasting rules apply, see the numpy.linalg documentation for details.
The solutions are computed using LAPACK routine _gesv
a must be square and of full-rank, i.e., all rows (or, equivalently, columns)
must be linearly independent; if either is not true, use lstsq for the least-squares
best "solution" of the system/equation.
References
[R41] G. Strang, Linear Algebra and Its Applications, 2nd Ed., Orlando, FL, Academic Press, Inc., 1980, pg. 22.
Examples
Solve the system of equations 3 * x0 + x1 = 9 and x0 + 2 * x1 = 8:
>>> a = np.array([[3,1], [1,2]])
>>> b = np.array([9,8])
>>> x = np.linalg.solve(a, b)
>>> x
array([ 2., 3.])
Check that the solution is correct:
>>> np.allclose(np.dot(a, x), b)
True
"""
return _mx_nd_np.linalg.solve(a, b)
1 change: 1 addition & 0 deletions python/mxnet/numpy_dispatch_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs):
'linalg.norm',
'linalg.cholesky',
'linalg.inv',
'linalg.solve',
'shape',
'trace',
'tril',
Expand Down
56 changes: 55 additions & 1 deletion python/mxnet/symbol/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from . import _op as _mx_sym_np
from . import _internal as _npi

__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet']
__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet', 'solve']


def norm(x, ord=None, axis=None, keepdims=False):
Expand Down Expand Up @@ -339,3 +339,57 @@ def slogdet(a):
(1., -1151.2925464970228)
"""
return _npi.slogdet(a)

def solve(a, b):
r"""
Solve a linear matrix equation, or system of linear scalar equations.
Computes the "exact" solution, x, of the well-determined, i.e., full rank, linear matrix equation ax = b.
Parameters :
a : (..., M, M) ndarray
Coefficient matrix.
b : {(..., M,), (..., M, K)}, ndarray
Ordinate or "dependent variable" values.
Returns :
x : {(..., M,), (..., M, K)} ndarray
Solution to the system a x = b. Returned shape is identical to b.
Raises :
LinAlgError
If a is singular or not square.
Notes
Broadcasting rules apply, see the numpy.linalg documentation for details.
The solutions are computed using LAPACK routine _gesv
a must be square and of full-rank, i.e., all rows (or, equivalently, columns)
must be linearly independent; if either is not true, use lstsq for the least-squares
best "solution" of the system/equation.
References
[R41] G. Strang, Linear Algebra and Its Applications, 2nd Ed., Orlando, FL, Academic Press, Inc., 1980, pg. 22.
Examples
Solve the system of equations 3 * x0 + x1 = 9 and x0 + 2 * x1 = 8:
>>> a = np.array([[3,1], [1,2]])
>>> b = np.array([9,8])
>>> x = np.linalg.solve(a, b)
>>> x
array([ 2., 3.])
Check that the solution is correct:
>>> np.allclose(np.dot(a, x), b)
True
"""
return _npi.solve(a, b)
10 changes: 10 additions & 0 deletions src/operator/c_lapack_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,13 @@
return 1; \
}

#define MXNET_LAPACK_CWRAPPER7(func, dtype) \
int MXNET_LAPACK_##func(int matrix_order, int n, int nrhs, dtype *a, \
int lda, int *ipiv, dtype *b, int ldb) { \
LOG(FATAL) << "MXNet build without lapack. Function " << #func << " is not available."; \
return 1; \
}

#define MXNET_LAPACK_UNAVAILABLE(func) \
int mxnet_lapack_##func(...) { \
LOG(FATAL) << "MXNet build without lapack. Function " << #func << " is not available."; \
Expand Down Expand Up @@ -101,4 +108,7 @@
MXNET_LAPACK_CWRAPPER6(sgesvd, float)
MXNET_LAPACK_CWRAPPER6(dgesvd, double)

MXNET_LAPACK_CWRAPPER7(sgesv, float)
MXNET_LAPACK_CWRAPPER7(dgesv, double)

#endif // MSHADOW_USE_MKL == 0
39 changes: 37 additions & 2 deletions src/operator/c_lapack_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,19 @@ extern "C" {

MXNET_LAPACK_FSIG_GETRI(sgetri, float)
MXNET_LAPACK_FSIG_GETRI(dgetri, double)

#ifdef __ANDROID__
#define MXNET_LAPACK_FSIG_GESV(func, dtype) \
int func##_(int *n, int *nrhs, dtype *a, int *lda, \
int *ipiv, dtype *b, int *ldb, int *info);
#else
#define MXNET_LAPACK_FSIG_GESV(func, dtype) \
void func##_(int *n, int *nrhs, dtype *a, int *lda, \
int *ipiv, dtype *b, int *ldb, int *info);
#endif

MXNET_LAPACK_FSIG_GESV(sgesv, float)
MXNET_LAPACK_FSIG_GESV(dgesv, double)
}

#endif // MSHADOW_USE_MKL == 0
Expand Down Expand Up @@ -197,6 +210,8 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {
#define MXNET_LAPACK_dpotri LAPACKE_dpotri
#define mxnet_lapack_sposv LAPACKE_sposv
#define mxnet_lapack_dposv LAPACKE_dposv
#define MXNET_LAPACK_dgesv LAPACKE_dgesv
#define MXNET_LAPACK_sgesv LAPACKE_sgesv

// The following functions differ in signature from the
// MXNET_LAPACK-signature and have to be wrapped.
Expand Down Expand Up @@ -440,9 +455,23 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {
MXNET_LAPACK_CWRAP_GETRI(s, float)
MXNET_LAPACK_CWRAP_GETRI(d, double)

#else

#define MXNET_LAPACK_CWRAP_GESV(prefix, dtype) \
inline int MXNET_LAPACK_##prefix##gesv(int matrix_layout, \
int n, int nrhs, dtype *a, int lda, \
int *ipiv, dtype *b, int ldb) { \
if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \
CHECK(false) << "MXNET_LAPACK_" << #prefix << "gesv implemented for col-major layout only"; \
return 1; \
} else { \
int info(0); \
prefix##gesv_(&n, &nrhs, a, &lda, ipiv, b, &ldb, &info); \
return info; \
} \
}
MXNET_LAPACK_CWRAP_GESV(s, float)
MXNET_LAPACK_CWRAP_GESV(d, double)

#else

#define MXNET_LAPACK_ROW_MAJOR 101
#define MXNET_LAPACK_COL_MAJOR 102
Expand Down Expand Up @@ -473,6 +502,9 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {
int ldut, dtype* s, dtype* v, int ldv, \
dtype* work, int lwork);

#define MXNET_LAPACK_CWRAPPER7(func, dtype) \
int MXNET_LAPACK_##func(int matrix_order, int n, int nrhs, dtype *a, \
int lda, int *ipiv, dtype *b, int ldb); \

#define MXNET_LAPACK_UNAVAILABLE(func) \
int mxnet_lapack_##func(...);
Expand Down Expand Up @@ -501,6 +533,9 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {
MXNET_LAPACK_CWRAPPER6(sgesvd, float)
MXNET_LAPACK_CWRAPPER6(dgesvd, double)

MXNET_LAPACK_CWRAPPER7(sgesv, float)
MXNET_LAPACK_CWRAPPER7(dgesv, double)

#undef MXNET_LAPACK_CWRAPPER1
#undef MXNET_LAPACK_CWRAPPER2
#undef MXNET_LAPACK_CWRAPPER3
Expand Down
Loading

0 comments on commit 983bb2e

Please sign in to comment.