Skip to content

Commit

Permalink
Add matrix inversion operator in linalg (apache#14963)
Browse files Browse the repository at this point in the history
* add inverse cpu

* add comment

* add inverse backward cpu

* add inverse gpu

* able to compile

* fix

* fix

* guard for lower version cuda

* update docs

* update docs

* fix misaligned memory

* add test

* fix lint

* fix android

* fix indent

* change transfer gradient

* fix

* refactor test

* delete unnecessary copy

* trigger CI

* fix test
  • Loading branch information
arcadiaphy authored and haohuw committed Jun 23, 2019
1 parent 27f3d92 commit 8cb5aaf
Show file tree
Hide file tree
Showing 10 changed files with 556 additions and 13 deletions.
1 change: 1 addition & 0 deletions docs/api/python/symbol/linalg.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ In the rest of this document, we list routines provided by the `symbol.linalg` p
makediag
extracttrian
maketrian
inverse
```

## API Reference
Expand Down
26 changes: 23 additions & 3 deletions src/operator/c_lapack_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,29 @@

#define MXNET_LAPACK_CWRAPPER2(func, dtype) \
int MXNET_LAPACK_##func(int matrix_layout, int m, int n, dtype* a, \
int lda, dtype* tau, dtype* work, int lwork) { \
int lda, dtype* tau, dtype* work, int lwork) { \
LOG(FATAL) << "MXNet build without lapack. Function " << #func << " is not available."; \
return 1; \
}

#define MXNET_LAPACK_CWRAPPER3(func, dtype) \
int MXNET_LAPACK_##func(int matrix_layout, char uplo, int n, dtype *a, \
int lda, dtype *w, dtype *work, int lwork, \
int *iwork, int liwork) { \
int lda, dtype *w, dtype *work, int lwork, \
int *iwork, int liwork) { \
LOG(FATAL) << "MXNet build without lapack. Function " << #func << " is not available."; \
return 1; \
}

#define MXNET_LAPACK_CWRAPPER4(func, dtype) \
int MXNET_LAPACK_##func(int matrix_layout, int m, int n, \
dtype *a, int lda, int *ipiv) { \
LOG(FATAL) << "MXNet build without lapack. Function " << #func << " is not available."; \
return 1; \
}

#define MXNET_LAPACK_CWRAPPER5(func, dtype) \
int MXNET_LAPACK_##func(int matrix_layout, int n, dtype *a, int lda, \
int *ipiv, dtype *work, int lwork) { \
LOG(FATAL) << "MXNet build without lapack. Function " << #func << " is not available."; \
return 1; \
}
Expand All @@ -69,4 +83,10 @@

MXNET_LAPACK_CWRAPPER3(ssyevd, float)
MXNET_LAPACK_CWRAPPER3(dsyevd, double)

MXNET_LAPACK_CWRAPPER4(sgetrf, float)
MXNET_LAPACK_CWRAPPER4(dgetrf, double)

MXNET_LAPACK_CWRAPPER5(sgetri, float)
MXNET_LAPACK_CWRAPPER5(dgetri, double)
#endif // MSHADOW_USE_MKL == 0
100 changes: 93 additions & 7 deletions src/operator/c_lapack_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,30 @@ extern "C" {

MXNET_LAPACK_FSIG_SYEVD(ssyevd, float)
MXNET_LAPACK_FSIG_SYEVD(dsyevd, double)

#ifdef __ANDROID__
#define MXNET_LAPACK_FSIG_GETRF(func, dtype) \
int func##_(int *m, int *n, dtype *a, int *lda, int *ipiv, int *info);
#else
#define MXNET_LAPACK_FSIG_GETRF(func, dtype) \
void func##_(int *m, int *n, dtype *a, int *lda, int *ipiv, int *info);
#endif

MXNET_LAPACK_FSIG_GETRF(sgetrf, float)
MXNET_LAPACK_FSIG_GETRF(dgetrf, double)

#ifdef __ANDROID__
#define MXNET_LAPACK_FSIG_GETRI(func, dtype) \
int func##_(int *n, dtype *a, int *lda, int *ipiv, dtype *work, \
int *lwork, int *info);
#else
#define MXNET_LAPACK_FSIG_GETRI(func, dtype) \
void func##_(int *n, dtype *a, int *lda, int *ipiv, dtype *work, \
int *lwork, int *info);
#endif

MXNET_LAPACK_FSIG_GETRI(sgetri, float)
MXNET_LAPACK_FSIG_GETRI(dgetri, double)
}

#endif // MSHADOW_USE_MKL == 0
Expand Down Expand Up @@ -171,8 +195,8 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {
// MXNET_LAPACK-signature and have to be wrapped.
#define MXNET_LAPACK_CWRAP_GELQF(prefix, dtype) \
inline int MXNET_LAPACK_##prefix##gelqf(int matrix_layout, int m, int n, \
dtype *a, int lda, dtype* tau, \
dtype* work, int lwork) { \
dtype *a, int lda, dtype *tau, \
dtype *work, int lwork) { \
if (lwork != -1) { \
return LAPACKE_##prefix##gelqf(matrix_layout, m, n, a, lda, tau); \
} \
Expand All @@ -184,8 +208,8 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {

#define MXNET_LAPACK_CWRAP_ORGLQ(prefix, dtype) \
inline int MXNET_LAPACK_##prefix##orglq(int matrix_layout, int m, int n, \
dtype *a, int lda, dtype* tau, \
dtype* work, int lwork) { \
dtype *a, int lda, dtype *tau, \
dtype *work, int lwork) { \
if (lwork != -1) { \
return LAPACKE_##prefix##orglq(matrix_layout, m, n, m, a, lda, tau); \
} \
Expand Down Expand Up @@ -215,6 +239,21 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {
MXNET_LAPACK_CWRAP_SYEVD(s, float)
MXNET_LAPACK_CWRAP_SYEVD(d, double)

#define MXNET_LAPACK_sgetrf LAPACKE_sgetrf
#define MXNET_LAPACK_dgetrf LAPACKE_dgetrf

#define MXNET_LAPACK_CWRAP_GETRI(prefix, dtype) \
inline int MXNET_LAPACK_##prefix##getri(int matrix_layout, int n, dtype *a, int lda, \
int *ipiv, dtype *work, int lwork) { \
if (lwork != -1) { \
return LAPACKE_##prefix##getri(matrix_layout, n, a, lda, ipiv); \
} \
*work = 0; \
return 0; \
}
MXNET_LAPACK_CWRAP_GETRI(s, float)
MXNET_LAPACK_CWRAP_GETRI(d, double)

#elif MXNET_USE_LAPACK

#define MXNET_LAPACK_ROW_MAJOR 101
Expand Down Expand Up @@ -322,6 +361,38 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {
MXNET_LAPACK_CWRAP_SYEVD(ssyevd, float)
MXNET_LAPACK_CWRAP_SYEVD(dsyevd, double)

// Note: Both MXNET_LAPACK_*getrf, MXNET_LAPACK_*getri can only be called with col-major format
// (MXNet) for performance.
#define MXNET_LAPACK_CWRAP_GETRF(prefix, dtype) \
inline int MXNET_LAPACK_##prefix##getrf(int matrix_layout, int m, int n, \
dtype *a, int lda, int *ipiv) { \
if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \
CHECK(false) << "MXNET_LAPACK_" << #prefix << "getri implemented for col-major layout only"; \
return 1; \
} else { \
int info(0); \
prefix##getrf_(&m, &n, a, &lda, ipiv, &info); \
return info; \
} \
}
MXNET_LAPACK_CWRAP_GETRF(s, float)
MXNET_LAPACK_CWRAP_GETRF(d, double)

#define MXNET_LAPACK_CWRAP_GETRI(prefix, dtype) \
inline int MXNET_LAPACK_##prefix##getri(int matrix_layout, int n, dtype *a, int lda, \
int *ipiv, dtype *work, int lwork) { \
if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \
CHECK(false) << "MXNET_LAPACK_" << #prefix << "getri implemented for col-major layout only"; \
return 1; \
} else { \
int info(0); \
prefix##getri_(&n, a, &lda, ipiv, work, &lwork, &info); \
return info; \
} \
}
MXNET_LAPACK_CWRAP_GETRI(s, float)
MXNET_LAPACK_CWRAP_GETRI(d, double)

#else


Expand All @@ -335,12 +406,20 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {

#define MXNET_LAPACK_CWRAPPER2(func, dtype) \
int MXNET_LAPACK_##func(int matrix_layout, int m, int n, dtype* a, \
int lda, dtype* tau, dtype* work, int lwork);
int lda, dtype* tau, dtype* work, int lwork);

#define MXNET_LAPACK_CWRAPPER3(func, dtype) \
int MXNET_LAPACK_##func(int matrix_layout, char uplo, int n, dtype *a, \
int lda, dtype *w, dtype *work, int lwork, \
int *iwork, int liwork);
int lda, dtype *w, dtype *work, int lwork, \
int *iwork, int liwork);

#define MXNET_LAPACK_CWRAPPER4(func, dtype) \
int MXNET_LAPACK_##func(int matrix_layout, int m, int n, \
dtype *a, int lda, int *ipiv);

#define MXNET_LAPACK_CWRAPPER5(func, dtype) \
int MXNET_LAPACK_##func(int matrix_layout, int n, dtype *a, int lda, \
int *ipiv, dtype *work, int lwork);

#define MXNET_LAPACK_UNAVAILABLE(func) \
int mxnet_lapack_##func(...);
Expand All @@ -359,9 +438,16 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {

MXNET_LAPACK_CWRAPPER3(ssyevd, float)
MXNET_LAPACK_CWRAPPER3(dsyevd, double)

MXNET_LAPACK_CWRAPPER4(sgetrf, float)
MXNET_LAPACK_CWRAPPER4(dgetrf, double)

MXNET_LAPACK_CWRAPPER5(sgetri, float)
MXNET_LAPACK_CWRAPPER5(dgetri, double)
#undef MXNET_LAPACK_CWRAPPER1
#undef MXNET_LAPACK_CWRAPPER2
#undef MXNET_LAPACK_CWRAPPER3
#undef MXNET_LAPACK_CWRAPPER4
#undef MXNET_LAPACK_UNAVAILABLE
#endif

Expand Down
49 changes: 49 additions & 0 deletions src/operator/linalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,55 @@ int linalg_syevd_workspace_query(const Tensor<xpu, 2, DType>& A,
const Tensor<xpu, 1, DType>& L,
Stream<xpu> *s = 0);

//////////////////////////////// GETRF ////////////////////////////////////////////

// CPU/GPU-versions of LAPACK function "getrf". Please refer to the
// LAPACK documentation for further details.
// Note that this is A = getrf(A), so A is input and output parameter.

template<typename xpu, typename DType>
void linalg_getrf(const Tensor<xpu, 2, DType>& A,
const Tensor<xpu, 1, DType>& work,
Stream<xpu> *s = 0);

template<typename xpu, typename DType>
void linalg_batch_getrf(const Tensor<xpu, 3, DType>& A,
const Tensor<xpu, 1, DType>& work,
Stream<xpu> *s = 0);

//////////////////////////////// GETRI ////////////////////////////////////////////

// CPU/GPU-versions of LAPACK function "getri". Please refer to the
// LAPACK documentation for further details.
// Note that this is A = getri(A), so A is input and output parameter.

template<typename xpu, typename DType>
void linalg_getri(const Tensor<xpu, 2, DType>& A,
const Tensor<xpu, 1, DType>& work,
Stream<xpu> *s = 0);

template<typename xpu, typename DType>
void linalg_batch_getri(const Tensor<xpu, 3, DType>& A,
const Tensor<xpu, 3, DType>& B,
const Tensor<xpu, 1, DType>& work,
Stream<xpu> *s = 0);

// This function determines the amount of workspace needed for linalg_getri to operate
// on a batch of matrices which is returned as number of elements of type DType.
template<typename xpu, typename DType>
int linalg_getri_workspace_query(const Tensor<xpu, 3, DType>& A,
Stream<xpu> *s = 0);

//////////////////////////////// INVERSE ////////////////////////////////////////////

// CPU/GPU-versions of matrix inversion combining LAPACK function "getrf" and "getri"
// Note that A = inverse(B)
template<typename xpu, typename DType>
void linalg_batch_inverse(const Tensor<xpu, 3, DType>& A,
const Tensor<xpu, 3, DType>& B,
const Tensor<xpu, 1, DType>& work,
Stream<xpu> *s = 0);

#include "linalg_impl.h"

#endif // MXNET_OPERATOR_LINALG_H_
Loading

0 comments on commit 8cb5aaf

Please sign in to comment.