Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add matrix inversion operator in linalg #14963

Merged
merged 21 commits into from
May 20, 2019
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/python/symbol/linalg.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ In the rest of this document, we list routines provided by the `symbol.linalg` p
makediag
extracttrian
maketrian
inverse
```

## API Reference
Expand Down
26 changes: 23 additions & 3 deletions src/operator/c_lapack_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,29 @@

#define MXNET_LAPACK_CWRAPPER2(func, dtype) \
int MXNET_LAPACK_##func(int matrix_layout, int m, int n, dtype* a, \
int lda, dtype* tau, dtype* work, int lwork) { \
int lda, dtype* tau, dtype* work, int lwork) { \
LOG(FATAL) << "MXNet build without lapack. Function " << #func << " is not available."; \
return 1; \
}

#define MXNET_LAPACK_CWRAPPER3(func, dtype) \
int MXNET_LAPACK_##func(int matrix_layout, char uplo, int n, dtype *a, \
int lda, dtype *w, dtype *work, int lwork, \
int *iwork, int liwork) { \
int lda, dtype *w, dtype *work, int lwork, \
int *iwork, int liwork) { \
LOG(FATAL) << "MXNet build without lapack. Function " << #func << " is not available."; \
return 1; \
}

#define MXNET_LAPACK_CWRAPPER4(func, dtype) \
int MXNET_LAPACK_##func(int matrix_layout, int m, int n, \
dtype *a, int lda, int *ipiv) { \
LOG(FATAL) << "MXNet build without lapack. Function " << #func << " is not available."; \
return 1; \
}

#define MXNET_LAPACK_CWRAPPER5(func, dtype) \
int MXNET_LAPACK_##func(int matrix_layout, int n, dtype *a, int lda, \
int *ipiv, dtype *work, int lwork) { \
LOG(FATAL) << "MXNet build without lapack. Function " << #func << " is not available."; \
return 1; \
}
Expand All @@ -69,4 +83,10 @@

MXNET_LAPACK_CWRAPPER3(ssyevd, float)
MXNET_LAPACK_CWRAPPER3(dsyevd, double)

MXNET_LAPACK_CWRAPPER4(sgetrf, float)
MXNET_LAPACK_CWRAPPER4(dgetrf, double)

MXNET_LAPACK_CWRAPPER5(sgetri, float)
MXNET_LAPACK_CWRAPPER5(dgetri, double)
#endif // MSHADOW_USE_MKL == 0
100 changes: 93 additions & 7 deletions src/operator/c_lapack_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,30 @@ extern "C" {

MXNET_LAPACK_FSIG_SYEVD(ssyevd, float)
MXNET_LAPACK_FSIG_SYEVD(dsyevd, double)

#ifdef __ANDROID__
#define MXNET_LAPACK_FSIG_GETRF(func, dtype) \
int func##_(int *m, int *n, dtype *a, int *lda, int *ipiv, int *info);
#else
#define MXNET_LAPACK_FSIG_GETRF(func, dtype) \
void func##_(int *m, int *n, dtype *a, int *lda, int *ipiv, int *info);
#endif

MXNET_LAPACK_FSIG_GETRF(sgetrf, float)
MXNET_LAPACK_FSIG_GETRF(dgetrf, double)

#ifdef __ANDROID__
#define MXNET_LAPACK_FSIG_GETRI(func, dtype) \
int func##_(int *n, dtype *a, int *lda, int *ipiv, dtype *work, \
int *lwork, int *info);
#else
#define MXNET_LAPACK_FSIG_GETRI(func, dtype) \
void func##_(int *n, dtype *a, int *lda, int *ipiv, dtype *work, \
int *lwork, int *info);
#endif

MXNET_LAPACK_FSIG_GETRI(sgetri, float)
MXNET_LAPACK_FSIG_GETRI(dgetri, double)
}

#endif // MSHADOW_USE_MKL == 0
Expand Down Expand Up @@ -171,8 +195,8 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {
// MXNET_LAPACK-signature and have to be wrapped.
#define MXNET_LAPACK_CWRAP_GELQF(prefix, dtype) \
inline int MXNET_LAPACK_##prefix##gelqf(int matrix_layout, int m, int n, \
dtype *a, int lda, dtype* tau, \
dtype* work, int lwork) { \
dtype *a, int lda, dtype *tau, \
dtype *work, int lwork) { \
if (lwork != -1) { \
return LAPACKE_##prefix##gelqf(matrix_layout, m, n, a, lda, tau); \
} \
Expand All @@ -184,8 +208,8 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {

#define MXNET_LAPACK_CWRAP_ORGLQ(prefix, dtype) \
inline int MXNET_LAPACK_##prefix##orglq(int matrix_layout, int m, int n, \
dtype *a, int lda, dtype* tau, \
dtype* work, int lwork) { \
dtype *a, int lda, dtype *tau, \
dtype *work, int lwork) { \
if (lwork != -1) { \
return LAPACKE_##prefix##orglq(matrix_layout, m, n, m, a, lda, tau); \
} \
Expand Down Expand Up @@ -215,6 +239,21 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {
MXNET_LAPACK_CWRAP_SYEVD(s, float)
MXNET_LAPACK_CWRAP_SYEVD(d, double)

#define MXNET_LAPACK_sgetrf LAPACKE_sgetrf
#define MXNET_LAPACK_dgetrf LAPACKE_dgetrf

#define MXNET_LAPACK_CWRAP_GETRI(prefix, dtype) \
inline int MXNET_LAPACK_##prefix##getri(int matrix_layout, int n, dtype *a, int lda, \
int *ipiv, dtype *work, int lwork) { \
if (lwork != -1) { \
return LAPACKE_##prefix##getri(matrix_layout, n, a, lda, ipiv); \
} \
*work = 0; \
return 0; \
}
MXNET_LAPACK_CWRAP_GETRI(s, float)
MXNET_LAPACK_CWRAP_GETRI(d, double)

#elif MXNET_USE_LAPACK

#define MXNET_LAPACK_ROW_MAJOR 101
Expand Down Expand Up @@ -322,6 +361,38 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {
MXNET_LAPACK_CWRAP_SYEVD(ssyevd, float)
MXNET_LAPACK_CWRAP_SYEVD(dsyevd, double)

// Note: Both MXNET_LAPACK_*getrf, MXNET_LAPACK_*getri can only be called with col-major format
// (MXNet) for performance.
#define MXNET_LAPACK_CWRAP_GETRF(prefix, dtype) \
inline int MXNET_LAPACK_##prefix##getrf(int matrix_layout, int m, int n, \
dtype *a, int lda, int *ipiv) { \
if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \
CHECK(false) << "MXNET_LAPACK_" << #prefix << "getri implemented for col-major layout only"; \
return 1; \
} else { \
int info(0); \
prefix##getrf_(&m, &n, a, &lda, ipiv, &info); \
return info; \
} \
}
MXNET_LAPACK_CWRAP_GETRF(s, float)
MXNET_LAPACK_CWRAP_GETRF(d, double)

#define MXNET_LAPACK_CWRAP_GETRI(prefix, dtype) \
inline int MXNET_LAPACK_##prefix##getri(int matrix_layout, int n, dtype *a, int lda, \
int *ipiv, dtype *work, int lwork) { \
if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \
CHECK(false) << "MXNET_LAPACK_" << #prefix << "getri implemented for col-major layout only"; \
return 1; \
} else { \
int info(0); \
prefix##getri_(&n, a, &lda, ipiv, work, &lwork, &info); \
return info; \
} \
}
MXNET_LAPACK_CWRAP_GETRI(s, float)
MXNET_LAPACK_CWRAP_GETRI(d, double)

#else


Expand All @@ -335,12 +406,20 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {

#define MXNET_LAPACK_CWRAPPER2(func, dtype) \
int MXNET_LAPACK_##func(int matrix_layout, int m, int n, dtype* a, \
int lda, dtype* tau, dtype* work, int lwork);
int lda, dtype* tau, dtype* work, int lwork);

#define MXNET_LAPACK_CWRAPPER3(func, dtype) \
int MXNET_LAPACK_##func(int matrix_layout, char uplo, int n, dtype *a, \
int lda, dtype *w, dtype *work, int lwork, \
int *iwork, int liwork);
int lda, dtype *w, dtype *work, int lwork, \
int *iwork, int liwork);

#define MXNET_LAPACK_CWRAPPER4(func, dtype) \
int MXNET_LAPACK_##func(int matrix_layout, int m, int n, \
dtype *a, int lda, int *ipiv);

#define MXNET_LAPACK_CWRAPPER5(func, dtype) \
int MXNET_LAPACK_##func(int matrix_layout, int n, dtype *a, int lda, \
int *ipiv, dtype *work, int lwork);

#define MXNET_LAPACK_UNAVAILABLE(func) \
int mxnet_lapack_##func(...);
Expand All @@ -359,9 +438,16 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {

MXNET_LAPACK_CWRAPPER3(ssyevd, float)
MXNET_LAPACK_CWRAPPER3(dsyevd, double)

MXNET_LAPACK_CWRAPPER4(sgetrf, float)
MXNET_LAPACK_CWRAPPER4(dgetrf, double)

MXNET_LAPACK_CWRAPPER5(sgetri, float)
MXNET_LAPACK_CWRAPPER5(dgetri, double)
#undef MXNET_LAPACK_CWRAPPER1
#undef MXNET_LAPACK_CWRAPPER2
#undef MXNET_LAPACK_CWRAPPER3
#undef MXNET_LAPACK_CWRAPPER4
#undef MXNET_LAPACK_UNAVAILABLE
#endif

Expand Down
49 changes: 49 additions & 0 deletions src/operator/linalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,55 @@ int linalg_syevd_workspace_query(const Tensor<xpu, 2, DType>& A,
const Tensor<xpu, 1, DType>& L,
Stream<xpu> *s = 0);

//////////////////////////////// GETRF ////////////////////////////////////////////

// CPU/GPU-versions of LAPACK function "getrf". Please refer to the
// LAPACK documentation for further details.
// Note that this is A = getrf(A), so A is input and output parameter.

template<typename xpu, typename DType>
void linalg_getrf(const Tensor<xpu, 2, DType>& A,
const Tensor<xpu, 1, DType>& work,
Stream<xpu> *s = 0);

template<typename xpu, typename DType>
void linalg_batch_getrf(const Tensor<xpu, 3, DType>& A,
const Tensor<xpu, 1, DType>& work,
Stream<xpu> *s = 0);

//////////////////////////////// GETRI ////////////////////////////////////////////

// CPU/GPU-versions of LAPACK function "getri". Please refer to the
// LAPACK documentation for further details.
// Note that this is A = getri(A), so A is input and output parameter.

template<typename xpu, typename DType>
void linalg_getri(const Tensor<xpu, 2, DType>& A,
const Tensor<xpu, 1, DType>& work,
Stream<xpu> *s = 0);

template<typename xpu, typename DType>
void linalg_batch_getri(const Tensor<xpu, 3, DType>& A,
const Tensor<xpu, 3, DType>& B,
const Tensor<xpu, 1, DType>& work,
Stream<xpu> *s = 0);

// This function determines the amount of workspace needed for linalg_getri to operate
// on a batch of matrices which is returned as number of elements of type DType.
template<typename xpu, typename DType>
int linalg_getri_workspace_query(const Tensor<xpu, 3, DType>& A,
Stream<xpu> *s = 0);

//////////////////////////////// INVERSE ////////////////////////////////////////////

// CPU/GPU-versions of matrix inversion combining LAPACK function "getrf" and "getri"
// Note that A = inverse(B)
template<typename xpu, typename DType>
void linalg_batch_inverse(const Tensor<xpu, 3, DType>& A,
const Tensor<xpu, 3, DType>& B,
const Tensor<xpu, 1, DType>& work,
Stream<xpu> *s = 0);

#include "linalg_impl.h"

#endif // MXNET_OPERATOR_LINALG_H_
Loading