Skip to content

Commit

Permalink
Add matrix determinant operator in linalg (apache#15007)
Browse files Browse the repository at this point in the history
* add backbone

* cpu forward det

* refactor for gpu forward det

* fix

* register gpu det forward

* add gpu det backward

* register gpu det backward

* fix

* add logdet slogdet backward

* stop grad for zero det

* fix

* fix

* reduce grad transfer

* fix docs

* update comments

* fix docs

* fix lint

* add test

* update docs

* add operator

* update test

* trigger CI

* remove slash

* update comments and docs

* update det helper function

* update operator check

* remove logdet

* add no grad when det = 0

* update comments and docs

* remove remaining logdet
  • Loading branch information
arcadiaphy authored and Ubuntu committed Aug 20, 2019
1 parent 90e3d3f commit f6fc011
Show file tree
Hide file tree
Showing 9 changed files with 679 additions and 137 deletions.
2 changes: 2 additions & 0 deletions docs/api/python/symbol/linalg.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ In the rest of this document, we list routines provided by the `symbol.linalg` p
extracttrian
maketrian
inverse
det
slogdet
```

## API Reference
Expand Down
4 changes: 4 additions & 0 deletions python/mxnet/contrib/amp/lists/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,8 @@
'_linalg_maketrian',
'_linalg_extracttrian',
'_linalg_inverse',
'_linalg_det',
'_linalg_slogdet',
'linalg_syrk',
'linalg_potrf',
'linalg_potri',
Expand All @@ -446,6 +448,8 @@
'linalg_maketrian',
'linalg_extracttrian',
'linalg_inverse',
'linalg_det',
'linalg_slogdet',
'_NDArray',
'_Native',
'_contrib_count_sketch',
Expand Down
50 changes: 34 additions & 16 deletions src/operator/linalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -195,50 +195,68 @@ int linalg_syevd_workspace_query(const Tensor<xpu, 2, DType>& A,

// CPU/GPU-versions of LAPACK function "getrf". Please refer to the
// LAPACK documentation for further details.
// Note that this is A = getrf(A), so A is input and output parameter.

// Note:
// - A is input and output parameter (overwritten by LU)
// - Param check_singular is only useful in cpu version. If check_singular is false,
// don't throw error when A is non-invertible matrix.
template<typename xpu, typename DType>
void linalg_getrf(const Tensor<xpu, 2, DType>& A,
const Tensor<xpu, 1, DType>& work,
const Tensor<xpu, 1, int>& pivot,
bool check_singular,
Stream<xpu> *s = 0);

template<typename xpu, typename DType>
void linalg_batch_getrf(const Tensor<xpu, 3, DType>& A,
const Tensor<xpu, 1, DType>& work,
const Tensor<xpu, 2, int>& pivot,
bool check_singular,
Stream<xpu> *s = 0);

//////////////////////////////// GETRI ////////////////////////////////////////////

// CPU/GPU-versions of LAPACK function "getri". Please refer to the
// LAPACK documentation for further details.
// Note that this is A = getri(A), so A is input and output parameter.

// Note:
// - pivot and LU is the output of getrf(A)
// - LU is also the output parameter (overwritten by inverse(A))
template<typename xpu, typename DType>
void linalg_getri(const Tensor<xpu, 2, DType>& A,
void linalg_getri(const Tensor<xpu, 2, DType>& LU,
const Tensor<xpu, 1, int>& pivot, \
const Tensor<xpu, 1, DType>& work,
Stream<xpu> *s = 0);

// Note that this function only implements GPU version with "getriBatched" in cuBLAS.
// Unlike lapack routines in cpu, it is computed out-of-place, so the final matrix
// inverse is stored in A.
template<typename xpu, typename DType>
void linalg_batch_getri(const Tensor<xpu, 3, DType>& A,
const Tensor<xpu, 3, DType>& B,
const Tensor<xpu, 1, DType>& work,
const Tensor<xpu, 3, DType>& LU,
const Tensor<xpu, 2, int>& pivot,
Stream<xpu> *s = 0);

// This function determines the amount of workspace needed for linalg_getri to operate
// on a batch of matrices which is returned as number of elements of type DType.
template<typename xpu, typename DType>
int linalg_getri_workspace_query(const Tensor<xpu, 3, DType>& A,
Stream<xpu> *s = 0);

//////////////////////////////// INVERSE ////////////////////////////////////////////

// CPU/GPU-versions of matrix inversion combining LAPACK function "getrf" and "getri"
// CPU/GPU-versions of matrix inverse combining LAPACK function "getrf" and "getri"
// Note that A = inverse(B)
template<typename xpu, typename DType>
void linalg_batch_inverse(const Tensor<xpu, 3, DType>& A,
const Tensor<xpu, 3, DType>& B,
const Tensor<xpu, 1, DType>& work,
Stream<xpu> *s = 0);
const mxnet::OpContext& ctx);

//////////////////////////////// DET ////////////////////////////////////////////

// CPU/GPU-versions of helper functions used in matrix determinant operators

// Helper function in determinant backward computation: compute matrix inverse
// from LU and pivot using temp workspace, the result is stored back to LU
template<typename xpu, typename DType>
void linalg_batch_det_backward_helper(const Tensor<xpu, 3, DType>& LU,
const Tensor<xpu, 2, int>& pivot,
const Tensor<xpu, 1, DType>& det,
const Tensor<xpu, 3, DType>& temp,
const DType zero_det,
const mxnet::OpContext& ctx);

#include "linalg_impl.h"

Expand Down
Loading

0 comments on commit f6fc011

Please sign in to comment.