From f34e0792eab42e3da3d1318dc385d1bf2dcde154 Mon Sep 17 00:00:00 2001 From: Wang Jiajun Date: Fri, 17 May 2019 01:58:49 +0800 Subject: [PATCH 01/30] add backbone --- src/operator/tensor/la_op-inl.h | 9 +++++++++ src/operator/tensor/la_op.cc | 36 +++++++++++++++++++++++++++++++++ src/operator/tensor/la_op.h | 21 +++++++++++++++++++ 3 files changed, 66 insertions(+) diff --git a/src/operator/tensor/la_op-inl.h b/src/operator/tensor/la_op-inl.h index 4dead87b3dce..0552b55ea081 100644 --- a/src/operator/tensor/la_op-inl.h +++ b/src/operator/tensor/la_op-inl.h @@ -469,6 +469,15 @@ struct inverse { } }; +// A = det(B). +struct det { + template + static void op(const Tensor& B, const Tensor& A, + const OpContext& ctx, const nnvm::NodeAttrs& attrs) { + Stream *s = ctx.get_stream(); + } +}; + // Backward operators (always using batch processing) struct gemm_backward { diff --git a/src/operator/tensor/la_op.cc b/src/operator/tensor/la_op.cc index 2fa1fd3a1cb2..3190883416c1 100644 --- a/src/operator/tensor/la_op.cc +++ b/src/operator/tensor/la_op.cc @@ -939,5 +939,41 @@ NNVM_REGISTER_OP(_backward_linalg_inverse) .set_attr("TIsBackward", true) .set_attr("FCompute", LaOpBackward); +NNVM_REGISTER_OP(_linalg_det) +.add_alias("linalg_det") +.describe(R"code(Compute the determinant of a matrix. +Input is a tensor *A* of dimension *n >= 2*. + +If *n=2*, *A* is a square matrix. We compute: + + *out* = *det(A)* + +If *n>2*, *det* is performed separately on the trailing two dimensions +for all inputs (batch mode). + +.. note:: The operator supports float32 and float64 data types only. + +Examples:: + + // Single matrix inversion + A = [[1., 4.], [2., 3.]] + det(A) = [-5.] + + // Batch matrix inversion + A = [[[1., 4.], [2., 3.]], + [[1., 3.], [2., 4.]]] + det(A) = [-5., -2.] +)code" ADD_FILELINE) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr("FListInputNames", [](const NodeAttrs& attrs) + { return std::vector{"A"}; } ) +.set_attr("FInferShape", DetShape) +.set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FResourceRequest", [](const NodeAttrs& attrs) + { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("FCompute", LaOpForward) +.add_argument("A", "NDArray-or-Symbol", "Tensor of square matrix"); + } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/la_op.h b/src/operator/tensor/la_op.h index 5b0c7e3562dc..ee75c459de30 100644 --- a/src/operator/tensor/la_op.h +++ b/src/operator/tensor/la_op.h @@ -413,6 +413,27 @@ inline bool InverseShape(const nnvm::NodeAttrs& attrs, return true; } +// Shape inference function for linalg_det +// Inputs: A. Outputs: det(A) +inline bool DetShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector* in_attrs, + mxnet::ShapeVector* out_attrs) { + CHECK_EQ(in_attrs->size(), 1); + CHECK_EQ(out_attrs->size(), 1); + const mxnet::TShape& in = (*in_attrs)[0]; + const int ndim(in.ndim()); + CHECK_GE(ndim, 2) << "Input A's dimension must be >= 2"; + CHECK_EQ(in[ndim-2], in[ndim-1]) << "Input A's last two dimension must be equal"; + mxnet::TShape out; + if (ndim == 2) { + out = mxnet::TShape(1, 1); + } else { + out = mxnet::TShape(in.begin(), in.end() - 2); + } + SHAPE_ASSIGN_CHECK(*out_attrs, 0, out); + return true; +} + // Shape inference function for linalg_syevd // Inputs: A. Outputs: U, L inline bool LaEigFactShape(const nnvm::NodeAttrs& attrs, From ce29d9dd6b973fc0f942c047621da31d8c6b29a0 Mon Sep 17 00:00:00 2001 From: Wang Jiajun Date: Sat, 18 May 2019 14:17:21 +0800 Subject: [PATCH 02/30] cpu forward det --- src/operator/linalg.h | 4 ++-- src/operator/linalg_impl.h | 26 +++++++++++++++------ src/operator/tensor/la_op-inl.h | 36 +++++++++++++++++++++++++++-- src/operator/tensor/la_op.cc | 10 ++++---- src/operator/tensor/la_op.h | 41 +++++++++++++++++++++++++++++++-- 5 files changed, 100 insertions(+), 17 deletions(-) diff --git a/src/operator/linalg.h b/src/operator/linalg.h index ee713e5548c0..fbef929cc300 100644 --- a/src/operator/linalg.h +++ b/src/operator/linalg.h @@ -197,9 +197,9 @@ int linalg_syevd_workspace_query(const Tensor& A, // LAPACK documentation for further details. // Note that this is A = getrf(A), so A is input and output parameter. -template +template void linalg_getrf(const Tensor& A, - const Tensor& work, + const Tensor& work, Stream *s = 0); template diff --git a/src/operator/linalg_impl.h b/src/operator/linalg_impl.h index 718e3f9c5aa0..5f3480334acd 100644 --- a/src/operator/linalg_impl.h +++ b/src/operator/linalg_impl.h @@ -1240,26 +1240,38 @@ LINALG_GPU_SYEVD_WORKSPACE_QUERY(DnDsyevd, double) // The input of this function should be col-major for performance. // Tensor work holds space for ipiv in getrf -#define LINALG_CPU_GETRF(fname, DType) \ +#define LINALG_CPU_GETRF1(fname, DType) \ template<> inline \ -void linalg_getrf(const Tensor& A, \ - const Tensor& work, \ - Stream *s) { \ +void linalg_getrf(const Tensor& A, \ + const Tensor& work, \ + Stream *s) { \ int *ipiv = reinterpret_cast(work.dptr_); \ int ret(MXNET_LAPACK_##fname(MXNET_LAPACK_COL_MAJOR, A.size(1), A.size(0), \ A.dptr_, A.stride_, ipiv)); \ CHECK_EQ(ret, 0) << #fname << " failed in lapack on cpu."; \ } -LINALG_CPU_GETRF(sgetrf, float) -LINALG_CPU_GETRF(dgetrf, double) +#define LINALG_CPU_GETRF2(fname, DType) \ +template<> inline \ +void linalg_getrf(const Tensor& A, \ + const Tensor& work, \ + Stream *s) { \ + int ret(MXNET_LAPACK_##fname(MXNET_LAPACK_COL_MAJOR, A.size(1), A.size(0), \ + A.dptr_, A.stride_, work.dptr_)); \ + CHECK_EQ(ret, 0) << #fname << " failed in lapack on cpu."; \ +} + +LINALG_CPU_GETRF1(sgetrf, float) +LINALG_CPU_GETRF1(dgetrf, double) +LINALG_CPU_GETRF2(sgetrf, float) +LINALG_CPU_GETRF2(dgetrf, double) #ifdef __CUDACC__ // "getrfBatched" and "getriBatched" in cuBLAS must have DType *matrices[] as input // to store the pointers of each batch matrix. This kernel is used to build the // pointer array. -struct set_matrix : public mxnet::op::mxnet_op::tunable { +struct set_matrix { template MSHADOW_XINLINE static void Map(int i, DType **p, DType *m, int step) { p[i] = m + i * step; diff --git a/src/operator/tensor/la_op-inl.h b/src/operator/tensor/la_op-inl.h index 0552b55ea081..c7dc165a30ab 100644 --- a/src/operator/tensor/la_op-inl.h +++ b/src/operator/tensor/la_op-inl.h @@ -469,12 +469,44 @@ struct inverse { } }; -// A = det(B). +struct SignedLogDet { + template + MSHADOW_XINLINE static void Map(int i, int N, int* pivot, + DType *LU, DType* sign, DType *logdet) { + int changes(0); + DType diag_sign(1); + DType diag_logsum(0); + int *pivot_mat = pivot + i * N; + DType *LU_mat = LU + i * N * N; + for ( int j = 0; j < N; ++j ) { + changes += (pivot_mat[j] == (j + 1)); + DType diag = LU_mat[j * (N + 1)]; + diag_sign *= ((DType(0) < diag) - (diag < DType(0))); + diag_logsum += std::log(std::abs(diag)); + } + sign[i] = (changes % 2 == 1 ? DType(-1) : DType(1)) * diag_sign; + logdet[i] = diag_logsum; + } +}; +// det = det(A), LU and pivot store the LU decomposition output which will be +// used in computing gradient struct det { template - static void op(const Tensor& B, const Tensor& A, + static void op(const Tensor& A, const Tensor& det, + const Tensor& LU, const Tensor& pivot, const OpContext& ctx, const nnvm::NodeAttrs& attrs) { Stream *s = ctx.get_stream(); + Tensor sign = ctx.requested[0] + .get_space_typed(det.shape_, s); + Copy(LU, A, s); + for(index_t i = 0; i < A.size(0); ++i) { + linalg_getrf(LU[i], pivot[i], s); + } + using namespace mxnet_op; + using namespace mshadow::expr; + Kernel::Launch(s, pivot.size(0), pivot.size(1), pivot.dptr_, + LU.dptr_, sign.dptr_, det.dptr_); + const_cast&>(det) = sign * F(det); } }; diff --git a/src/operator/tensor/la_op.cc b/src/operator/tensor/la_op.cc index 3190883416c1..bf56d3f17f9a 100644 --- a/src/operator/tensor/la_op.cc +++ b/src/operator/tensor/la_op.cc @@ -965,14 +965,16 @@ Examples:: det(A) = [-5., -2.] )code" ADD_FILELINE) .set_num_inputs(1) -.set_num_outputs(1) +.set_num_outputs(3) .set_attr("FListInputNames", [](const NodeAttrs& attrs) - { return std::vector{"A"}; } ) + { return std::vector{"A", "LU", "pivot"}; }) +.set_attr("FNumVisibleOutputs", [](const NodeAttrs& attrs) { + return 1; }) .set_attr("FInferShape", DetShape) -.set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FInferType", DetType) .set_attr("FResourceRequest", [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) -.set_attr("FCompute", LaOpForward) +.set_attr("FCompute", LaOpDetForward) .add_argument("A", "NDArray-or-Symbol", "Tensor of square matrix"); } // namespace op diff --git a/src/operator/tensor/la_op.h b/src/operator/tensor/la_op.h index ee75c459de30..4f9c04efef21 100644 --- a/src/operator/tensor/la_op.h +++ b/src/operator/tensor/la_op.h @@ -419,7 +419,7 @@ inline bool DetShape(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector* in_attrs, mxnet::ShapeVector* out_attrs) { CHECK_EQ(in_attrs->size(), 1); - CHECK_EQ(out_attrs->size(), 1); + CHECK_EQ(out_attrs->size(), 3); const mxnet::TShape& in = (*in_attrs)[0]; const int ndim(in.ndim()); CHECK_GE(ndim, 2) << "Input A's dimension must be >= 2"; @@ -430,7 +430,24 @@ inline bool DetShape(const nnvm::NodeAttrs& attrs, } else { out = mxnet::TShape(in.begin(), in.end() - 2); } - SHAPE_ASSIGN_CHECK(*out_attrs, 0, out); + SHAPE_ASSIGN_CHECK(*out_attrs, 0, out); /* det */ + SHAPE_ASSIGN_CHECK(*out_attrs, 1, in); /* LU */ + SHAPE_ASSIGN_CHECK(*out_attrs, 2, mxnet::TShape(in.begin(), in.end() - 1)); /* pivot */ + return true; +} + +inline bool DetType(const nnvm::NodeAttrs& attrs, + std::vector* in_type, + std::vector* out_type) { + using namespace mshadow; + CHECK_EQ(in_type->size(), 1U); + int dtype = (*in_type)[0]; + CHECK_NE(dtype, -1) << "Input must have specified type"; + + out_type->clear(); + out_type->push_back(dtype); + out_type->push_back(dtype); + out_type->push_back(mshadow::kInt32); return true; } @@ -774,6 +791,26 @@ void LaOpBackwSyevd(const nnvm::NodeAttrs& attrs, }); } +// (A) => (det, LU, pivot) +template +void LaOpDetForward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + CHECK_EQ(inputs.size(), 1); + CHECK_EQ(outputs.size(), 3); + MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, OType, { + mshadow::Stream *s = ctx.get_stream(); + laop::op(inputs[0].FlatToKD(s), + outputs[0].FlatToKD(s), + outputs[1].FlatToKD(s), + outputs[2].FlatToKD(s), ctx, attrs); + }); +} + + } // namespace op } // namespace mxnet From 744d0ee5c0e5612980c4a10bc94428450f1e7a7a Mon Sep 17 00:00:00 2001 From: Wang Jiajun Date: Sun, 19 May 2019 13:58:11 +0800 Subject: [PATCH 03/30] refactor for gpu forward det --- src/operator/linalg.h | 35 +++--- src/operator/linalg_impl.h | 192 ++++++++++++++++---------------- src/operator/tensor/la_op-inl.h | 55 +++++++-- src/operator/tensor/la_op.cc | 93 +++++++++++++++- src/operator/tensor/la_op.h | 74 +++++++++--- 5 files changed, 303 insertions(+), 146 deletions(-) diff --git a/src/operator/linalg.h b/src/operator/linalg.h index fbef929cc300..25f08d6f1d34 100644 --- a/src/operator/linalg.h +++ b/src/operator/linalg.h @@ -195,40 +195,42 @@ int linalg_syevd_workspace_query(const Tensor& A, // CPU/GPU-versions of LAPACK function "getrf". Please refer to the // LAPACK documentation for further details. -// Note that this is A = getrf(A), so A is input and output parameter. +// Note: +// - A is input and output parameter (overwritten by LU) +// - Param check_singular is only useful in cpu version. +// If check_singular is false, don't throw error when A is in-convertible matrix. -template +template void linalg_getrf(const Tensor& A, - const Tensor& work, + const Tensor& pivot, + bool check_singular, Stream *s = 0); template void linalg_batch_getrf(const Tensor& A, - const Tensor& work, + const Tensor& pivot, + bool check_singular, Stream *s = 0); //////////////////////////////// GETRI //////////////////////////////////////////// // CPU/GPU-versions of LAPACK function "getri". Please refer to the // LAPACK documentation for further details. -// Note that this is A = getri(A), so A is input and output parameter. +// Note: +// - pivot and LU is the output of getrf(A) +// - LU is also the output parameter (overwritten by inverse(A)) template -void linalg_getri(const Tensor& A, +void linalg_getri(const Tensor& LU, + const Tensor& pivot, \ const Tensor& work, Stream *s = 0); template void linalg_batch_getri(const Tensor& A, - const Tensor& B, - const Tensor& work, - Stream *s = 0); - -// This function determines the amount of workspace needed for linalg_getri to operate -// on a batch of matrices which is returned as number of elements of type DType. -template -int linalg_getri_workspace_query(const Tensor& A, - Stream *s = 0); + const Tensor& LU, + const Tensor& pivot, + Stream *s = 0); //////////////////////////////// INVERSE //////////////////////////////////////////// @@ -237,8 +239,7 @@ int linalg_getri_workspace_query(const Tensor& A, template void linalg_batch_inverse(const Tensor& A, const Tensor& B, - const Tensor& work, - Stream *s = 0); + const mxnet::OpContext& ctx); #include "linalg_impl.h" diff --git a/src/operator/linalg_impl.h b/src/operator/linalg_impl.h index 5f3480334acd..4cd9ca46f1be 100644 --- a/src/operator/linalg_impl.h +++ b/src/operator/linalg_impl.h @@ -1240,31 +1240,36 @@ LINALG_GPU_SYEVD_WORKSPACE_QUERY(DnDsyevd, double) // The input of this function should be col-major for performance. // Tensor work holds space for ipiv in getrf -#define LINALG_CPU_GETRF1(fname, DType) \ +#define LINALG_CPU_GETRF(fname, DType) \ template<> inline \ -void linalg_getrf(const Tensor& A, \ - const Tensor& work, \ - Stream *s) { \ - int *ipiv = reinterpret_cast(work.dptr_); \ +void linalg_getrf(const Tensor& A, \ + const Tensor& pivot, \ + bool check_singular, Stream *s) { \ int ret(MXNET_LAPACK_##fname(MXNET_LAPACK_COL_MAJOR, A.size(1), A.size(0), \ - A.dptr_, A.stride_, ipiv)); \ - CHECK_EQ(ret, 0) << #fname << " failed in lapack on cpu."; \ + A.dptr_, A.stride_, pivot.dptr_)); \ + CHECK_GE(ret, 0) << #fname << " failed in lapack on cpu."; \ + if (check_singular) { \ + CHECK_EQ(ret, 0) << "the input matrix is in-convertible"; \ + } \ } -#define LINALG_CPU_GETRF2(fname, DType) \ +LINALG_CPU_GETRF(sgetrf, float) +LINALG_CPU_GETRF(dgetrf, double) + + +#define LINALG_CPU_BATCH_GETRF(fname, DType) \ template<> inline \ -void linalg_getrf(const Tensor& A, \ - const Tensor& work, \ - Stream *s) { \ - int ret(MXNET_LAPACK_##fname(MXNET_LAPACK_COL_MAJOR, A.size(1), A.size(0), \ - A.dptr_, A.stride_, work.dptr_)); \ - CHECK_EQ(ret, 0) << #fname << " failed in lapack on cpu."; \ +void linalg_batch_getrf(const Tensor& A, \ + const Tensor& pivot, \ + bool check_singular, \ + Stream *s) { \ + for(index_t i = 0; i < A.size(0); ++i) { \ + linalg_getrf(A[i], pivot[i], check_singular); \ + } \ } -LINALG_CPU_GETRF1(sgetrf, float) -LINALG_CPU_GETRF1(dgetrf, double) -LINALG_CPU_GETRF2(sgetrf, float) -LINALG_CPU_GETRF2(dgetrf, double) +LINALG_CPU_BATCH_GETRF(sgetrf, float) +LINALG_CPU_BATCH_GETRF(dgetrf, double) #ifdef __CUDACC__ @@ -1289,23 +1294,22 @@ struct set_matrix { #define LINALG_GPU_BATCH_GETRF(fname, DType) \ template<> inline \ void linalg_batch_getrf(const Tensor& A, \ - const Tensor& work, \ + const Tensor& pivot, \ + bool check_singular, \ Stream *s) { \ using namespace mxnet; \ using namespace mxnet::op::mxnet_op; \ CHECK_NOTNULL(s); \ + Storage::Handle info = Storage::Get()->Alloc(sizeof(int) * A.size(0), Context::GPU()); \ Storage::Handle A_ptr_buf = Storage::Get()->Alloc(sizeof(DType *) * A.size(0), Context::GPU()); \ DType **A_ptr = static_cast(A_ptr_buf.dptr); \ - const Tensor temp(work.dptr_, A.shape_, s); \ - int *pivot = reinterpret_cast(temp.dptr_ + temp.shape_.Size()); \ - int *info = pivot + A.size(0) * A.size(1); \ - Copy(temp, A, s); \ - Kernel::Launch(s, temp.size(0), \ - A_ptr, temp.dptr_, \ - temp.size(1) * temp.size(2)); \ + Kernel::Launch(s, A.size(0), \ + A_ptr, A.dptr_, \ + A.size(1) * A.size(2)); \ CUBLAS_CALL(cublas##fname(Stream::GetBlasHandle(s), \ - A.size(1), A_ptr, A.size(2), pivot, \ - info, A.size(0))) \ + A.size(1), A_ptr, A.size(2), pivot.dptr_, \ + static_cast(info.dptr), A.size(0))) \ + Storage::Get()->Free(info); \ Storage::Get()->Free(A_ptr_buf); \ } @@ -1314,7 +1318,8 @@ void linalg_batch_getrf(const Tensor& A, \ #define LINALG_GPU_BATCH_GETRF(fname, DType) \ template<> inline \ void linalg_batch_getrf(const Tensor& A, \ - const Tensor& work, \ + const Tensor& pivot, \ + bool check_singular, \ Stream *s) { \ LOG(FATAL) << "batched getrf requires CUDA version >= 8.0!"; \ } @@ -1331,38 +1336,36 @@ LINALG_GPU_BATCH_GETRF(DgetrfBatched, double) // CPU/GPU-versions of LAPACK function "getri" // The input of this function should be col-major for performance. -// Tensor work holds space for ipiv, work in getri #define LINALG_CPU_GETRI(fname, DType) \ template<> inline \ -void linalg_getri(const Tensor& A, \ +void linalg_getri(const Tensor& LU, \ + const Tensor& pivot, \ const Tensor& work, \ Stream *s) { \ - DType wkopt; \ - MXNET_LAPACK_##fname(MXNET_LAPACK_COL_MAJOR, A.size(0), A.dptr_, \ - A.stride_, nullptr, &wkopt, -1); \ - int lwork(static_cast(wkopt)); \ - int *ipiv = reinterpret_cast(work.dptr_); \ - DType *pwork = reinterpret_cast(ipiv + A.size(0)); \ - int ret(MXNET_LAPACK_##fname(MXNET_LAPACK_COL_MAJOR, A.size(0), A.dptr_, \ - A.stride_, ipiv, pwork, lwork)); \ + int ret(MXNET_LAPACK_##fname(MXNET_LAPACK_COL_MAJOR, LU.size(0), LU.dptr_, \ + LU.stride_, pivot.dptr_, work.dptr_, work.size(0))); \ CHECK_EQ(ret, 0) << #fname << " failed in lapack on cpu."; \ } LINALG_CPU_GETRI(sgetri, float) LINALG_CPU_GETRI(dgetri, double) -// Query workspace for the whole batch of matrices.For cpu version, the workspace -// is re-used, so space for only one matrix is enough. +template +int linalg_getri_workspace_query(const Tensor& A, \ + Stream *s) { + LOG(FATAL) << "it only takes float or double Tensor"; +} + +// Query workspace for "getri" #define LINALG_CPU_GETRI_WORKSPACE_QUERY(func, DType) \ template<> inline \ -int linalg_getri_workspace_query(const Tensor& A, \ +int linalg_getri_workspace_query(const Tensor& A, \ Stream *s) { \ - const Tensor& matrix = A[0]; \ DType lwork(0); \ - MXNET_LAPACK_##func(MXNET_LAPACK_COL_MAJOR, matrix.size(0), matrix.dptr_, \ - matrix.stride_, nullptr, &lwork, -1); \ - int ipiv = (sizeof(int) * matrix.size(0) + sizeof(DType) - 1) / sizeof(DType); \ - return ipiv + static_cast(lwork); \ + MXNET_LAPACK_##func(MXNET_LAPACK_COL_MAJOR, A.size(0), A.dptr_, \ + A.stride_, nullptr, &lwork, -1); \ + return lwork; \ } + LINALG_CPU_GETRI_WORKSPACE_QUERY(sgetri, float) LINALG_CPU_GETRI_WORKSPACE_QUERY(dgetri, double) @@ -1379,41 +1382,30 @@ LINALG_CPU_GETRI_WORKSPACE_QUERY(dgetri, double) #define LINALG_GPU_BATCH_GETRI(fname, DType) \ template<> inline \ void linalg_batch_getri(const Tensor& A, \ - const Tensor& B, \ - const Tensor& work, \ + const Tensor& LU, \ + const Tensor& pivot, \ Stream *s) { \ using namespace mxnet; \ using namespace mxnet::op::mxnet_op; \ CHECK_NOTNULL(s); \ + Storage::Handle info = Storage::Get()->Alloc(sizeof(int) * A.size(0), Context::GPU()); \ Storage::Handle A_ptr_buf = Storage::Get()->Alloc(sizeof(DType *) * A.size(0), Context::GPU()); \ DType **A_ptr = static_cast(A_ptr_buf.dptr); \ - Storage::Handle B_ptr_buf = Storage::Get()->Alloc(sizeof(DType *) * A.size(0), Context::GPU()); \ - DType **B_ptr = static_cast(B_ptr_buf.dptr); \ - Tensor temp(work.dptr_, A.shape_, s); \ - int *pivot = reinterpret_cast(temp.dptr_ + temp.shape_.Size()); \ - int *info = pivot + A.size(0) * A.size(1); \ + Storage::Handle LU_ptr_buf = Storage::Get()->Alloc(sizeof(DType *) * A.size(0), Context::GPU()); \ + DType **LU_ptr = static_cast(LU_ptr_buf.dptr); \ Kernel::Launch(s, A.size(0), \ A_ptr, A.dptr_, \ A.size(1) * A.size(2)); \ - Kernel::Launch(s, temp.size(0), \ - B_ptr, temp.dptr_, \ - temp.size(1) * temp.size(2)); \ - CUBLAS_CALL(cublas##fname(Stream::GetBlasHandle(s), \ - A.size(1), const_cast(B_ptr), \ - B.size(2), const_cast(pivot), \ - A_ptr, A.size(2), info, A.size(0))) \ + Kernel::Launch(s, LU.size(0), \ + LU_ptr, LU.dptr_, \ + LU.size(1) * LU.size(2)); \ + CUBLAS_CALL(cublas##fname(Stream::GetBlasHandle(s), A.size(1), \ + const_cast(LU_ptr), B.size(2), \ + const_cast(pivot.dptr_), A_ptr, A.size(2), \ + static_cast(info.dptr), A.size(0))) \ + Storage::Get()->Free(info); \ Storage::Get()->Free(A_ptr_buf); \ - Storage::Get()->Free(B_ptr_buf); \ -} - -#define LINALG_GPU_GETRI_WORKSPACE_QUERY(fname, DType) \ -template<> inline \ -int linalg_getri_workspace_query(const Tensor& A, \ - Stream *s) { \ - int pivot_size = sizeof(int) * A.size(0) * A.size(1); \ - int info_size = sizeof(int) * A.size(0); \ - int matrix_size = sizeof(DType) * A.shape_.Size(); \ - return (pivot_size + info_size + matrix_size + sizeof(DType) - 1) / sizeof(DType); \ + Storage::Get()->Free(LU_ptr_buf); \ } #else @@ -1421,27 +1413,17 @@ int linalg_getri_workspace_query(const Tensor& A, \ #define LINALG_GPU_BATCH_GETRI(fname, DType) \ template<> inline \ void linalg_batch_getri(const Tensor& A, \ - const Tensor& B, \ - const Tensor& work, \ + const Tensor& LU, \ + const Tensor& pivot, \ Stream *s) { \ LOG(FATAL) << "batched getri requires CUDA version >= 8.0!"; \ } -#define LINALG_GPU_GETRI_WORKSPACE_QUERY(fname, DType) \ -template<> inline \ -int linalg_getri_workspace_query(const Tensor& A, \ - Stream *s) { \ - LOG(FATAL) << "batched getri requires CUDA version >= 8.0!"; \ -} - #endif // CUDA_VERSION >= 8000 LINALG_GPU_BATCH_GETRI(SgetriBatched, float) LINALG_GPU_BATCH_GETRI(DgetriBatched, double) -LINALG_GPU_GETRI_WORKSPACE_QUERY(SgetriBatched, float) -LINALG_GPU_GETRI_WORKSPACE_QUERY(DgetriBatched, double) - #endif // __CUDACC__ //////////////////////////////// INVERSE //////////////////////////////////////////// @@ -1453,17 +1435,27 @@ LINALG_GPU_GETRI_WORKSPACE_QUERY(DgetriBatched, double) template<> inline \ void linalg_batch_inverse(const Tensor& A, \ const Tensor& B, \ - const Tensor& work, \ - Stream *s) { \ + const mxnet::OpContext& ctx) { \ + Stream *s = ctx.get_stream(); \ + int lwork(linalg_getri_workspace_query(A[0], s)); \ + int workspace_size = (sizeof(int) * A.size(1) + sizeof(DType) * lwork + \ + sizeof(DType) - 1) / sizeof(DType); \ + Tensor workspace = ctx.requested[0].\ + get_space_typed(Shape1(workspace_size), s); \ + const Tensor pivot(reinterpret_cast(workspace.dptr_), \ + Shape1(A.size(1))); \ + const Tensor work(reinterpret_cast(pivot.dptr_ + pivot.MSize()), \ + Shape1(lwork)); \ if (A.dptr_ != B.dptr_) Copy(A, B, s); \ for (index_t i = 0; i < A.size(0); ++i) { \ - linalg_getrf(A[i], work, s); \ - linalg_getri(A[i], work, s); \ + linalg_getrf(A[i], pivot, true, s); \ + linalg_getri(A[i], pivot, work, s); \ } \ } LINALG_CPU_BATCH_INVERSE(cpu, float) LINALG_CPU_BATCH_INVERSE(cpu, double) + #ifdef __CUDACC__ // GETRF and GETRI only available with cuda8 or higher. @@ -1473,10 +1465,21 @@ LINALG_CPU_BATCH_INVERSE(cpu, double) template<> inline \ void linalg_batch_inverse(const Tensor& A, \ const Tensor& B, \ - const Tensor& work, \ - Stream *s) { \ - linalg_batch_getrf(B, work, s); \ - linalg_batch_getri(A, B, work, s); \ + const mxnet::OpContext& ctx) { \ + Stream *s = ctx.get_stream(); \ + int pivot_size = sizeof(int) * A.size(0) * A.size(1); \ + int matrix_size = sizeof(DType) * A.Size(); \ + int workspace_size = (pivot_size + matrix_size + \ + sizeof(DType) - 1) / sizeof(DType); \ + Tensor workspace = ctx.requested[0].\ + get_space_typed(Shape1(workspace_size), s); \ + const Tensor pivot(reinterpret_cast(workspace.dptr_), \ + Shape2(A.size(0), A.size(1))); \ + const Tensor temp(reinterpret_cast(pivot.dptr_ + pivot.MSize()), \ + A.shape_); \ + Copy(temp, A, s); \ + linalg_batch_getrf(temp, pivot, true, s); \ + linalg_batch_getri(A, temp, pivot, s); \ } #else @@ -1485,9 +1488,8 @@ void linalg_batch_inverse(const Tensor& A, \ template<> inline \ void linalg_batch_inverse(const Tensor& A, \ const Tensor& B, \ - const Tensor& work, \ - Stream *s) { \ - LOG(FATAL) << "batched getrf and getri requires CUDA version >= 8.0!"; \ + const mxnet::OpContext& ctx) { \ + LOG(FATAL) << "gpu matrix inversion requires CUDA version >= 8.0!"; \ } #endif // CUDA_VERSION >= 8000 diff --git a/src/operator/tensor/la_op-inl.h b/src/operator/tensor/la_op-inl.h index c7dc165a30ab..38eee3469c6b 100644 --- a/src/operator/tensor/la_op-inl.h +++ b/src/operator/tensor/la_op-inl.h @@ -458,17 +458,15 @@ struct inverse { template static void op(const Tensor& B, const Tensor& A, const OpContext& ctx, const nnvm::NodeAttrs& attrs) { - Stream *s = ctx.get_stream(); - // Reserve workspace (size determined by query) - int lwork(linalg_getri_workspace_query(A, s)); - Tensor work = ctx.requested[0] - .get_space_typed(Shape1(lwork), s); // Since inverse(A) = trans(inverse(trans(A))), so we don't need to transpose // A even if we are using the col-major version of getrf and getri routines. - linalg_batch_inverse(A, B, work, s); + linalg_batch_inverse(A, B, ctx); } }; +// partial pivoting LU decomposition: A = PLU, so det(A) = det(P)det(L)det(U) +// det(P) depends on number of row changes in P, det(L) = 1, det(U) = prod(diag(U)) +// this kernel computes sign(det(A)), log(abs(det(A))) struct SignedLogDet { template MSHADOW_XINLINE static void Map(int i, int N, int* pivot, @@ -478,7 +476,7 @@ struct SignedLogDet { DType diag_logsum(0); int *pivot_mat = pivot + i * N; DType *LU_mat = LU + i * N * N; - for ( int j = 0; j < N; ++j ) { + for (int j = 0; j < N; ++j) { changes += (pivot_mat[j] == (j + 1)); DType diag = LU_mat[j * (N + 1)]; diag_sign *= ((DType(0) < diag) - (diag < DType(0))); @@ -488,6 +486,7 @@ struct SignedLogDet { logdet[i] = diag_logsum; } }; + // det = det(A), LU and pivot store the LU decomposition output which will be // used in computing gradient struct det { @@ -499,9 +498,8 @@ struct det { Tensor sign = ctx.requested[0] .get_space_typed(det.shape_, s); Copy(LU, A, s); - for(index_t i = 0; i < A.size(0); ++i) { - linalg_getrf(LU[i], pivot[i], s); - } + // since det(A) = det(trans(A)), so we'll use col-major blas routines here + linalg_batch_getrf(LU, pivot, false, s); using namespace mxnet_op; using namespace mshadow::expr; Kernel::Launch(s, pivot.size(0), pivot.size(1), pivot.dptr_, @@ -510,6 +508,43 @@ struct det { } }; +// logdet = log(det(A)) +struct logdet { + template + static void op(const Tensor& A, const Tensor& logdet, + const Tensor& LU, const Tensor& pivot, + const OpContext& ctx, const nnvm::NodeAttrs& attrs) { + Stream *s = ctx.get_stream(); + Tensor sign = ctx.requested[0] + .get_space_typed(logdet.shape_, s); + Copy(LU, A, s); + linalg_batch_getrf(LU, pivot, false, s); + using namespace mxnet_op; + using namespace mshadow::expr; + Kernel::Launch(s, pivot.size(0), pivot.size(1), pivot.dptr_, + LU.dptr_, sign.dptr_, logdet.dptr_); + const_cast&>(logdet) = F(sign) + logdet; + } +}; + +// sign = sign(det(A)) +// logdet = log(abs(det(A))) +struct slogdet { + template + static void op(const Tensor& A, const Tensor& sign, + const Tensor& logdet, const Tensor& LU, + const Tensor& pivot, const OpContext& ctx, + const nnvm::NodeAttrs& attrs) { + Stream *s = ctx.get_stream(); + Copy(LU, A, s); + linalg_batch_getrf(LU, pivot, false, s); + using namespace mxnet_op; + using namespace mshadow::expr; + Kernel::Launch(s, pivot.size(0), pivot.size(1), pivot.dptr_, + LU.dptr_, sign.dptr_, logdet.dptr_); + } +}; + // Backward operators (always using batch processing) struct gemm_backward { diff --git a/src/operator/tensor/la_op.cc b/src/operator/tensor/la_op.cc index bf56d3f17f9a..c58af0ae1db6 100644 --- a/src/operator/tensor/la_op.cc +++ b/src/operator/tensor/la_op.cc @@ -961,20 +961,101 @@ Examples:: // Batch matrix inversion A = [[[1., 4.], [2., 3.]], - [[1., 3.], [2., 4.]]] - det(A) = [-5., -2.] + [[2., 3.], [1., 4.]]] + det(A) = [-5., 5.] )code" ADD_FILELINE) .set_num_inputs(1) .set_num_outputs(3) .set_attr("FListInputNames", [](const NodeAttrs& attrs) - { return std::vector{"A", "LU", "pivot"}; }) + { return std::vector{"A"}; }) .set_attr("FNumVisibleOutputs", [](const NodeAttrs& attrs) { return 1; }) -.set_attr("FInferShape", DetShape) -.set_attr("FInferType", DetType) +.set_attr("FInferShape", DetShape<1>) +.set_attr("FInferType", DetType<1>) .set_attr("FResourceRequest", [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) -.set_attr("FCompute", LaOpDetForward) +.set_attr("FCompute", LaOpDetForward) +.add_argument("A", "NDArray-or-Symbol", "Tensor of square matrix"); + +NNVM_REGISTER_OP(_linalg_logdet) +.add_alias("linalg_logdet") +.describe(R"code(Compute the log determinant of a matrix. +Input is a tensor *A* of dimension *n >= 2*. + +If *n=2*, *A* is a square matrix. We compute: + + *out* = *log(det(A))* + +If *n>2*, *logdet* is performed separately on the trailing two dimensions +for all inputs (batch mode). + +.. note:: The operator supports float32 and float64 data types only. + +Examples:: + + // Single matrix inversion + A = [[2., 3.], [1., 4.]] + logdet(A) = [1.609438] + + // Batch matrix inversion + A = [[[2., 3.], [1., 4.]], + [[1., 2.], [2., 4.]], + [[1., 2.], [4., 3.]]] + logdet(A) = [1.609438, -inf, nan] +)code" ADD_FILELINE) +.set_num_inputs(1) +.set_num_outputs(3) +.set_attr("FListInputNames", [](const NodeAttrs& attrs) + { return std::vector{"A"}; }) +.set_attr("FNumVisibleOutputs", [](const NodeAttrs& attrs) { + return 1; }) +.set_attr("FInferShape", DetShape<1>) +.set_attr("FInferType", DetType<1>) +.set_attr("FResourceRequest", [](const NodeAttrs& attrs) + { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("FCompute", LaOpDetForward) +.add_argument("A", "NDArray-or-Symbol", "Tensor of square matrix"); + +NNVM_REGISTER_OP(_linalg_slogdet) +.add_alias("linalg_slogdet") +.describe(R"code(Compute the sign and log of the determinant of a matrix. +Input is a tensor *A* of dimension *n >= 2*. + +If *n=2*, *A* is a square matrix. We compute: + + *sign* = *sign(det(A))* + *logdet* = *log(abs(det(A)))* + +If *n>2*, *slogdet* is performed separately on the trailing two dimensions +for all inputs (batch mode). + +.. note:: The operator supports float32 and float64 data types only. + +Examples:: + + // Single matrix inversion + A = [[2., 3.], [1., 4.]] + sign, logdet = slogdet(A) + sign = [1.] + logdet = [1.609438] + + // Batch matrix inversion + A = [[[2., 3.], [1., 4.]], + [[1., 2.], [2., 4.]], + [[1., 2.], [4., 3.]]] + sign, logdet = slogdet(A) + sign = [1., 0., -1.] + logdet = [1.609438, -inf, 1.609438] +)code" ADD_FILELINE) +.set_num_inputs(1) +.set_num_outputs(4) +.set_attr("FListInputNames", [](const NodeAttrs& attrs) + { return std::vector{"A"}; }) +.set_attr("FNumVisibleOutputs", [](const NodeAttrs& attrs) { + return 2; }) +.set_attr("FInferShape", DetShape<2>) +.set_attr("FInferType", DetType<2>) +.set_attr("FCompute", LaOpDetForward) .add_argument("A", "NDArray-or-Symbol", "Tensor of square matrix"); } // namespace op diff --git a/src/operator/tensor/la_op.h b/src/operator/tensor/la_op.h index 4f9c04efef21..f9eb2348d0ef 100644 --- a/src/operator/tensor/la_op.h +++ b/src/operator/tensor/la_op.h @@ -413,13 +413,13 @@ inline bool InverseShape(const nnvm::NodeAttrs& attrs, return true; } -// Shape inference function for linalg_det -// Inputs: A. Outputs: det(A) +// Shape inference function for det functions in linalg +template inline bool DetShape(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector* in_attrs, mxnet::ShapeVector* out_attrs) { CHECK_EQ(in_attrs->size(), 1); - CHECK_EQ(out_attrs->size(), 3); + CHECK_EQ(out_attrs->size(), onum + 2); const mxnet::TShape& in = (*in_attrs)[0]; const int ndim(in.ndim()); CHECK_GE(ndim, 2) << "Input A's dimension must be >= 2"; @@ -430,12 +430,16 @@ inline bool DetShape(const nnvm::NodeAttrs& attrs, } else { out = mxnet::TShape(in.begin(), in.end() - 2); } - SHAPE_ASSIGN_CHECK(*out_attrs, 0, out); /* det */ - SHAPE_ASSIGN_CHECK(*out_attrs, 1, in); /* LU */ - SHAPE_ASSIGN_CHECK(*out_attrs, 2, mxnet::TShape(in.begin(), in.end() - 1)); /* pivot */ + for (int i = 0; i < onum; ++i) { + SHAPE_ASSIGN_CHECK(*out_attrs, i, out); /* sign or det or logdet */ + } + SHAPE_ASSIGN_CHECK(*out_attrs, onum, in); /* LU */ + SHAPE_ASSIGN_CHECK(*out_attrs, onum + 1, mxnet::TShape(in.begin(), in.end() - 1)); /* pivot */ return true; } +// Type inference function for det functions in linalg +template inline bool DetType(const nnvm::NodeAttrs& attrs, std::vector* in_type, std::vector* out_type) { @@ -445,9 +449,11 @@ inline bool DetType(const nnvm::NodeAttrs& attrs, CHECK_NE(dtype, -1) << "Input must have specified type"; out_type->clear(); - out_type->push_back(dtype); - out_type->push_back(dtype); - out_type->push_back(mshadow::kInt32); + for (int i = 0; i < onum; ++i) { + out_type->push_back(dtype); /* sign or det or logdet */ + } + out_type->push_back(dtype); /* LU */ + out_type->push_back(mshadow::kInt32); /* pivot */ return true; } @@ -791,8 +797,45 @@ void LaOpBackwSyevd(const nnvm::NodeAttrs& attrs, }); } -// (A) => (det, LU, pivot) -template + +template +struct LaOpDetCaller { + static void op(const std::vector& inputs, + const std::vector& outputs, + const nnvm::NodeAttrs& attrs, + const OpContext& ctx) { + CHECK(false) << "no specialized LaOpDetForward defined for template parameters"; + } +}; +template +struct LaOpDetCaller { + static void op(const std::vector& inputs, + const std::vector& outputs, + const nnvm::NodeAttrs& attrs, + const OpContext& ctx) { + mshadow::Stream *s = ctx.get_stream(); + laop::op(inputs[0].FlatToKD(s), + outputs[0].FlatToKD(s), + outputs[1].FlatToKD(s), + outputs[2].FlatToKD(s), ctx, attrs); + } +}; +template +struct LaOpDetCaller { + static void op(const std::vector& inputs, + const std::vector& outputs, + const nnvm::NodeAttrs& attrs, + const OpContext& ctx) { + mshadow::Stream *s = ctx.get_stream(); + laop::op(inputs[0].FlatToKD(s), + outputs[0].FlatToKD(s), + outputs[1].FlatToKD(s), + outputs[2].FlatToKD(s), + outputs[3].FlatToKD(s), ctx, attrs); + } +}; + +template void LaOpDetForward(const nnvm::NodeAttrs& attrs, const OpContext& ctx, const std::vector& inputs, @@ -800,17 +843,12 @@ void LaOpDetForward(const nnvm::NodeAttrs& attrs, const std::vector& outputs) { using namespace mshadow; CHECK_EQ(inputs.size(), 1); - CHECK_EQ(outputs.size(), 3); + CHECK_EQ(outputs.size(), onum + 2); MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, OType, { - mshadow::Stream *s = ctx.get_stream(); - laop::op(inputs[0].FlatToKD(s), - outputs[0].FlatToKD(s), - outputs[1].FlatToKD(s), - outputs[2].FlatToKD(s), ctx, attrs); + LaOpDetCaller::op(inputs, outputs, attrs, ctx); }); } - } // namespace op } // namespace mxnet From 8aa6547915098011ebbbd961bc165a2664d65281 Mon Sep 17 00:00:00 2001 From: Wang Jiajun Date: Sun, 19 May 2019 06:28:19 +0000 Subject: [PATCH 04/30] fix --- src/operator/linalg_impl.h | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/operator/linalg_impl.h b/src/operator/linalg_impl.h index 4cd9ca46f1be..9b9922dec13b 100644 --- a/src/operator/linalg_impl.h +++ b/src/operator/linalg_impl.h @@ -1353,6 +1353,7 @@ template int linalg_getri_workspace_query(const Tensor& A, \ Stream *s) { LOG(FATAL) << "it only takes float or double Tensor"; + return 0; } // Query workspace for "getri" @@ -1400,7 +1401,7 @@ void linalg_batch_getri(const Tensor& A, \ LU_ptr, LU.dptr_, \ LU.size(1) * LU.size(2)); \ CUBLAS_CALL(cublas##fname(Stream::GetBlasHandle(s), A.size(1), \ - const_cast(LU_ptr), B.size(2), \ + const_cast(LU_ptr), LU.size(2), \ const_cast(pivot.dptr_), A_ptr, A.size(2), \ static_cast(info.dptr), A.size(0))) \ Storage::Get()->Free(info); \ @@ -1468,18 +1469,18 @@ void linalg_batch_inverse(const Tensor& A, \ const mxnet::OpContext& ctx) { \ Stream *s = ctx.get_stream(); \ int pivot_size = sizeof(int) * A.size(0) * A.size(1); \ - int matrix_size = sizeof(DType) * A.Size(); \ + int matrix_size = sizeof(DType) * A.shape_.Size(); \ int workspace_size = (pivot_size + matrix_size + \ sizeof(DType) - 1) / sizeof(DType); \ Tensor workspace = ctx.requested[0].\ get_space_typed(Shape1(workspace_size), s); \ const Tensor pivot(reinterpret_cast(workspace.dptr_), \ Shape2(A.size(0), A.size(1))); \ - const Tensor temp(reinterpret_cast(pivot.dptr_ + pivot.MSize()), \ - A.shape_); \ - Copy(temp, A, s); \ - linalg_batch_getrf(temp, pivot, true, s); \ - linalg_batch_getri(A, temp, pivot, s); \ + const Tensor LU(reinterpret_cast(pivot.dptr_ + pivot.MSize()), \ + A.shape_); \ + Copy(LU, B, s); \ + linalg_batch_getrf(LU, pivot, true, s); \ + linalg_batch_getri(A, LU, pivot, s); \ } #else From 77ecd98e85462bb8a7f53d48b72ff2c4da4ca433 Mon Sep 17 00:00:00 2001 From: Wang Jiajun Date: Sun, 19 May 2019 06:35:55 +0000 Subject: [PATCH 05/30] register gpu det forward --- src/operator/tensor/la_op.cu | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/operator/tensor/la_op.cu b/src/operator/tensor/la_op.cu index 3ef714e00c18..c80385697c14 100644 --- a/src/operator/tensor/la_op.cu +++ b/src/operator/tensor/la_op.cu @@ -99,6 +99,15 @@ NNVM_REGISTER_OP(_linalg_inverse) NNVM_REGISTER_OP(_backward_linalg_inverse) .set_attr("FCompute", LaOpBackward); +NNVM_REGISTER_OP(_linalg_det) +.set_attr("FCompute", LaOpDetForward); + +NNVM_REGISTER_OP(_linalg_logdet) +.set_attr("FCompute", LaOpDetForward); + +NNVM_REGISTER_OP(_linalg_slogdet) +.set_attr("FCompute", LaOpDetForward); + #if MXNET_USE_CUSOLVER == 1 NNVM_REGISTER_OP(_linalg_potrf) From 4aa46204f2988e3bc6fb466b0999f378ebc8eac5 Mon Sep 17 00:00:00 2001 From: Wang Jiajun Date: Sun, 19 May 2019 21:34:50 +0800 Subject: [PATCH 06/30] add gpu det backward --- src/operator/linalg.h | 18 ++++++-- src/operator/linalg_impl.h | 57 +++++++++++++++++++++++++ src/operator/tensor/la_op-inl.h | 23 +++++++++- src/operator/tensor/la_op.cc | 9 ++++ src/operator/tensor/la_op.h | 75 ++++++++++++++++++++++++++++++--- 5 files changed, 173 insertions(+), 9 deletions(-) diff --git a/src/operator/linalg.h b/src/operator/linalg.h index 25f08d6f1d34..7a65cc5ff4be 100644 --- a/src/operator/linalg.h +++ b/src/operator/linalg.h @@ -197,8 +197,8 @@ int linalg_syevd_workspace_query(const Tensor& A, // LAPACK documentation for further details. // Note: // - A is input and output parameter (overwritten by LU) -// - Param check_singular is only useful in cpu version. -// If check_singular is false, don't throw error when A is in-convertible matrix. +// - Param check_singular is only useful in cpu version. If check_singular is false, +// don't throw error when A is non-invertible matrix. template void linalg_getrf(const Tensor& A, @@ -230,7 +230,7 @@ template void linalg_batch_getri(const Tensor& A, const Tensor& LU, const Tensor& pivot, - Stream *s = 0); + Stream *s = 0); //////////////////////////////// INVERSE //////////////////////////////////////////// @@ -241,6 +241,18 @@ void linalg_batch_inverse(const Tensor& A, const Tensor& B, const mxnet::OpContext& ctx); +//////////////////////////////// DET //////////////////////////////////////////// + +// CPU/GPU-versions of helper functions to compute matrix determinant +// Compute matrix inversion with LU and pivot using temp workspace, +// the result stores back to LU +template +void linalg_batch_det_helper(const Tensor& LU, + const Tensor& pivot, + const Tensor& det, + const Tensor& temp, + const mxnet::OpContext& ctx); + #include "linalg_impl.h" #endif // MXNET_OPERATOR_LINALG_H_ diff --git a/src/operator/linalg_impl.h b/src/operator/linalg_impl.h index 9b9922dec13b..988ef5430859 100644 --- a/src/operator/linalg_impl.h +++ b/src/operator/linalg_impl.h @@ -1500,4 +1500,61 @@ LINALG_GPU_BATCH_INVERSE(gpu, double) #endif // __CUDACC__ +//////////////////////////////// DET //////////////////////////////////////////// + +// CPU/GPU-versions of helper functions to compute matrix determinant + +#define LINALG_CPU_BATCH_DET_HELPER(xpu, DType) \ +template<> inline \ +void linalg_batch_det_helper(const Tensor& LU, \ + const Tensor& pivot, \ + const Tensor& det, \ + const Tensor& temp, \ + const mxnet::OpContext& ctx) { \ + Stream *s = ctx.get_stream(); \ + int lwork(linalg_getri_workspace_query(LU[0], s)); \ + Tensor work = ctx.requested[0].\ + get_space_typed(Shape1(lwork), s); \ + for (index_t i = 0; i < LU.size(0); ++i) { \ + if (det[i] != DType(0)) { \ + linalg_getri(LU[i], pivot[i], work, s); \ + } \ + } \ +} + +LINALG_CPU_BATCH_DET_HELPER(cpu, float) +LINALG_CPU_BATCH_DET_HELPER(cpu, double) + +// GETRF and GETRI only available with cuda8 or higher. +#if CUDA_VERSION >= 8000 + +#define LINALG_GPU_BATCH_DET_HELPER(xpu, DType) \ +template<> inline \ +void linalg_batch_det_helper(const Tensor& LU, \ + const Tensor& pivot, \ + const Tensor& det, \ + const Tensor& temp, \ + const mxnet::OpContext& ctx) { \ + Stream *s = ctx.get_stream(); \ + linalg_batch_getri(temp, LU, pivot, s); \ + Copy(LU, temp, s); \ +} + +#else + +#define LINALG_GPU_BATCH_DET_HELPER(xpu, DType) \ +template<> inline \ +void linalg_batch_det_helper(const Tensor& LU, \ + const Tensor& pivot, \ + const Tensor& det, \ + const Tensor& temp, \ + const mxnet::OpContext& ctx) { \ + LOG(FATAL) << "gpu matrix inversion requires CUDA version >= 8.0!"; \ +} + +#endif // CUDA_VERSION >= 8000 + +LINALG_GPU_BATCH_DET_HELPER(gpu, float) +LINALG_GPU_BATCH_DET_HELPER(gpu, double) + #endif // MXNET_OPERATOR_LINALG_IMPL_H_ diff --git a/src/operator/tensor/la_op-inl.h b/src/operator/tensor/la_op-inl.h index 38eee3469c6b..f702303ccbfb 100644 --- a/src/operator/tensor/la_op-inl.h +++ b/src/operator/tensor/la_op-inl.h @@ -477,7 +477,7 @@ struct SignedLogDet { int *pivot_mat = pivot + i * N; DType *LU_mat = LU + i * N * N; for (int j = 0; j < N; ++j) { - changes += (pivot_mat[j] == (j + 1)); + changes += (pivot_mat[j] != (j + 1)); DType diag = LU_mat[j * (N + 1)]; diag_sign *= ((DType(0) < diag) - (diag < DType(0))); diag_logsum += std::log(std::abs(diag)); @@ -901,6 +901,27 @@ struct inverse_backward { } }; +// Backward of det(A) is derived from Jacobi's formula. +// The closed form solution is pretty easy when A is invertible. +// TODO(arcadiaphy) add implementation for non-invertible case +struct det_backward { + template + static void op(const Tensor& ddet, + const Tensor& det, + const Tensor& LU, + const Tensor& pivot, + const Tensor& dA, + const OpContext& ctx, const nnvm::NodeAttrs& attrs) { + using namespace mshadow; + using namespace mshadow::expr; + // compute inverse(A) and stores it to LU + linalg_batch_det_helper(LU, pivot, det, dA, ctx); + const_cast&>(dA) = broadcast_to(reshape(det * ddet, \ + Shape3(det.size(0), 1, 1)), mxnet::TShape(LU.shape_)) * \ + transpose(LU, Shape3(0, 2, 1)); + } +}; + } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/la_op.cc b/src/operator/tensor/la_op.cc index c58af0ae1db6..9baf8a675a89 100644 --- a/src/operator/tensor/la_op.cc +++ b/src/operator/tensor/la_op.cc @@ -975,8 +975,17 @@ Examples:: .set_attr("FResourceRequest", [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) .set_attr("FCompute", LaOpDetForward) +.set_attr("FGradient", ElemwiseGradUseOut{"_backward_linalg_det"}) .add_argument("A", "NDArray-or-Symbol", "Tensor of square matrix"); +NNVM_REGISTER_OP(_backward_linalg_det) +.set_num_inputs(6) +.set_num_outputs(1) +.set_attr("FResourceRequest", [](const NodeAttrs& attrs) + { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("TIsBackward", true) +.set_attr("FCompute", LaOpDetBackward); + NNVM_REGISTER_OP(_linalg_logdet) .add_alias("linalg_logdet") .describe(R"code(Compute the log determinant of a matrix. diff --git a/src/operator/tensor/la_op.h b/src/operator/tensor/la_op.h index f9eb2348d0ef..bc3180dfd16c 100644 --- a/src/operator/tensor/la_op.h +++ b/src/operator/tensor/la_op.h @@ -799,7 +799,7 @@ void LaOpBackwSyevd(const nnvm::NodeAttrs& attrs, template -struct LaOpDetCaller { +struct LaOpDetForwardCaller { static void op(const std::vector& inputs, const std::vector& outputs, const nnvm::NodeAttrs& attrs, @@ -808,7 +808,7 @@ struct LaOpDetCaller { } }; template -struct LaOpDetCaller { +struct LaOpDetForwardCaller { static void op(const std::vector& inputs, const std::vector& outputs, const nnvm::NodeAttrs& attrs, @@ -821,7 +821,7 @@ struct LaOpDetCaller { } }; template -struct LaOpDetCaller { +struct LaOpDetForwardCaller { static void op(const std::vector& inputs, const std::vector& outputs, const nnvm::NodeAttrs& attrs, @@ -834,7 +834,6 @@ struct LaOpDetCaller { outputs[3].FlatToKD(s), ctx, attrs); } }; - template void LaOpDetForward(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -845,7 +844,73 @@ void LaOpDetForward(const nnvm::NodeAttrs& attrs, CHECK_EQ(inputs.size(), 1); CHECK_EQ(outputs.size(), onum + 2); MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, OType, { - LaOpDetCaller::op(inputs, outputs, attrs, ctx); + LaOpDetForwardCaller::op(inputs, outputs, attrs, ctx); + }); +} + +template +struct LaOpDetBackwardCaller { + static void op(const std::vector& inputs, + const std::vector& outputs, + const nnvm::NodeAttrs& attrs, + const OpContext& ctx) { + CHECK(false) << "no specialized LaOpDetBackward defined for template parameters"; + } +}; +template +struct LaOpDetBackwardCaller { + static void op(const std::vector& inputs, + const std::vector& outputs, + const nnvm::NodeAttrs& attrs, + const OpContext& ctx) { + mshadow::Stream *s = ctx.get_stream(); + laop::op(inputs[0].FlatToKD(s), + inputs[3].FlatToKD(s), + inputs[4].FlatToKD(s), + inputs[5].FlatToKD(s), + outputs[0].FlatToKD(s), ctx, attrs); + } +}; +template +struct LaOpDetBackwardCaller { + static void op(const std::vector& inputs, + const std::vector& outputs, + const nnvm::NodeAttrs& attrs, + const OpContext& ctx) { + mshadow::Stream *s = ctx.get_stream(); + laop::op(inputs[0].FlatToKD(s), + inputs[4].FlatToKD(s), + inputs[5].FlatToKD(s), + inputs[6].FlatToKD(s), + inputs[7].FlatToKD(s), + outputs[1].FlatToKD(s), ctx, attrs); + } +}; +template +void LaOpDetBackward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + Stream *s = ctx.get_stream(); + CHECK_EQ(inputs.size(), (onum + 2) * 2); + CHECK_EQ(outputs.size(), onum); + MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, OType, { + std::vector tspace(outputs); + for ( int i = 0; i < onum; ++i ) { + if ( req[i] == kAddTo ) { + tspace[i].dptr_ = ctx.requested[0] + .get_space_typed(Shape1(outputs[i].Size()), s).dptr_; + } + } + LaOpDetBackwardCaller::op(inputs, tspace, attrs, ctx); + for ( int i = 0; i < onum; ++i ) { + if ( req[i] == kAddTo ) { + Tensor out = outputs[i].FlatTo1D(s); + out += tspace[i].FlatTo1D(s); + } + } }); } From 324476c02eb5f09683bbadeb5ead67d02b9b25a6 Mon Sep 17 00:00:00 2001 From: Wang Jiajun Date: Sun, 19 May 2019 21:38:47 +0800 Subject: [PATCH 07/30] register gpu det backward --- src/operator/tensor/la_op.cu | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/operator/tensor/la_op.cu b/src/operator/tensor/la_op.cu index c80385697c14..efb8d1dc2fc4 100644 --- a/src/operator/tensor/la_op.cu +++ b/src/operator/tensor/la_op.cu @@ -102,6 +102,9 @@ NNVM_REGISTER_OP(_backward_linalg_inverse) NNVM_REGISTER_OP(_linalg_det) .set_attr("FCompute", LaOpDetForward); +NNVM_REGISTER_OP(_linalg_det_backward) +.set_attr("FCompute", LaOpDetBackward); + NNVM_REGISTER_OP(_linalg_logdet) .set_attr("FCompute", LaOpDetForward); From cedff671c59960b5d703cd7da31ac70f653387bc Mon Sep 17 00:00:00 2001 From: Wang Jiajun Date: Sun, 19 May 2019 21:42:10 +0800 Subject: [PATCH 08/30] fix --- src/operator/tensor/la_op.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/tensor/la_op.cu b/src/operator/tensor/la_op.cu index efb8d1dc2fc4..9217cd78a583 100644 --- a/src/operator/tensor/la_op.cu +++ b/src/operator/tensor/la_op.cu @@ -102,7 +102,7 @@ NNVM_REGISTER_OP(_backward_linalg_inverse) NNVM_REGISTER_OP(_linalg_det) .set_attr("FCompute", LaOpDetForward); -NNVM_REGISTER_OP(_linalg_det_backward) +NNVM_REGISTER_OP(_backward_linalg_det) .set_attr("FCompute", LaOpDetBackward); NNVM_REGISTER_OP(_linalg_logdet) From 46bd8ae65a6ff760989a4a77484e0d88aaabae6e Mon Sep 17 00:00:00 2001 From: Wang Jiajun Date: Mon, 20 May 2019 09:21:15 +0800 Subject: [PATCH 09/30] add logdet slogdet backward --- src/operator/linalg.h | 1 + src/operator/linalg_impl.h | 5 ++- src/operator/tensor/la_op-inl.h | 55 ++++++++++++++++++++++++++++++--- src/operator/tensor/la_op.cc | 19 ++++++++++++ src/operator/tensor/la_op.cu | 6 ++++ src/operator/tensor/la_op.h | 6 ++-- 6 files changed, 84 insertions(+), 8 deletions(-) diff --git a/src/operator/linalg.h b/src/operator/linalg.h index 7a65cc5ff4be..f3536fe19ff6 100644 --- a/src/operator/linalg.h +++ b/src/operator/linalg.h @@ -251,6 +251,7 @@ void linalg_batch_det_helper(const Tensor& LU, const Tensor& pivot, const Tensor& det, const Tensor& temp, + const DType zero_det, const mxnet::OpContext& ctx); #include "linalg_impl.h" diff --git a/src/operator/linalg_impl.h b/src/operator/linalg_impl.h index 988ef5430859..c0d46bad8a76 100644 --- a/src/operator/linalg_impl.h +++ b/src/operator/linalg_impl.h @@ -1510,13 +1510,14 @@ void linalg_batch_det_helper(const Tensor& LU, \ const Tensor& pivot, \ const Tensor& det, \ const Tensor& temp, \ + const DType zero_det, \ const mxnet::OpContext& ctx) { \ Stream *s = ctx.get_stream(); \ int lwork(linalg_getri_workspace_query(LU[0], s)); \ Tensor work = ctx.requested[0].\ get_space_typed(Shape1(lwork), s); \ for (index_t i = 0; i < LU.size(0); ++i) { \ - if (det[i] != DType(0)) { \ + if (det[i] != zero_det) { \ linalg_getri(LU[i], pivot[i], work, s); \ } \ } \ @@ -1534,6 +1535,7 @@ void linalg_batch_det_helper(const Tensor& LU, \ const Tensor& pivot, \ const Tensor& det, \ const Tensor& temp, \ + const DType zero_det, \ const mxnet::OpContext& ctx) { \ Stream *s = ctx.get_stream(); \ linalg_batch_getri(temp, LU, pivot, s); \ @@ -1548,6 +1550,7 @@ void linalg_batch_det_helper(const Tensor& LU, \ const Tensor& pivot, \ const Tensor& det, \ const Tensor& temp, \ + const DType zero_det, \ const mxnet::OpContext& ctx) { \ LOG(FATAL) << "gpu matrix inversion requires CUDA version >= 8.0!"; \ } diff --git a/src/operator/tensor/la_op-inl.h b/src/operator/tensor/la_op-inl.h index f702303ccbfb..8cf112c07d56 100644 --- a/src/operator/tensor/la_op-inl.h +++ b/src/operator/tensor/la_op-inl.h @@ -528,11 +528,11 @@ struct logdet { }; // sign = sign(det(A)) -// logdet = log(abs(det(A))) +// logabsdet = log(abs(det(A))) struct slogdet { template static void op(const Tensor& A, const Tensor& sign, - const Tensor& logdet, const Tensor& LU, + const Tensor& logabsdet, const Tensor& LU, const Tensor& pivot, const OpContext& ctx, const nnvm::NodeAttrs& attrs) { Stream *s = ctx.get_stream(); @@ -541,7 +541,7 @@ struct slogdet { using namespace mxnet_op; using namespace mshadow::expr; Kernel::Launch(s, pivot.size(0), pivot.size(1), pivot.dptr_, - LU.dptr_, sign.dptr_, logdet.dptr_); + LU.dptr_, sign.dptr_, logabsdet.dptr_); } }; @@ -903,6 +903,7 @@ struct inverse_backward { // Backward of det(A) is derived from Jacobi's formula. // The closed form solution is pretty easy when A is invertible. +// For non-invertible A, grad is not backwarded now. // TODO(arcadiaphy) add implementation for non-invertible case struct det_backward { template @@ -915,13 +916,59 @@ struct det_backward { using namespace mshadow; using namespace mshadow::expr; // compute inverse(A) and stores it to LU - linalg_batch_det_helper(LU, pivot, det, dA, ctx); + linalg_batch_det_helper(LU, pivot, det, dA, DType(0), ctx); const_cast&>(dA) = broadcast_to(reshape(det * ddet, \ Shape3(det.size(0), 1, 1)), mxnet::TShape(LU.shape_)) * \ transpose(LU, Shape3(0, 2, 1)); } }; +// Backward of logdet(A) is derived from Jacobi's formula. +// The closed form solution is pretty easy when A is invertible. +// For non-invertible A, grad is not backwarded now. +// TODO(arcadiaphy) add implementation for non-invertible case +struct logdet_backward { + template + static void op(const Tensor& dlogdet, + const Tensor& logdet, + const Tensor& LU, + const Tensor& pivot, + const Tensor& dA, + const OpContext& ctx, const nnvm::NodeAttrs& attrs) { + using namespace mshadow; + using namespace mshadow::expr; + // compute inverse(A) and stores it to LU + linalg_batch_det_helper(LU, pivot, logdet, dA, DType(-INFINITY), ctx); + const_cast&>(dA) = broadcast_to(reshape(dlogdet, \ + Shape3(logdet.size(0), 1, 1)), mxnet::TShape(LU.shape_)) * \ + transpose(LU, Shape3(0, 2, 1)); + } +}; + +// Backward of slogdet(A) is derived from Jacobi's formula. +// The closed form solution is pretty easy when A is invertible. +// For non-invertible A, grad is not backwarded now. +// Grad is not properly defined on sign, so it's not backwarded either. +// TODO(arcadiaphy) add implementation for non-invertible case +struct slogdet_backward { + template + static void op(const Tensor& dlogabsdet, + const Tensor& sign, + const Tensor& logabsdet, + const Tensor& LU, + const Tensor& pivot, + const Tensor& dA, + const OpContext& ctx, const nnvm::NodeAttrs& attrs) { + using namespace mshadow; + using namespace mshadow::expr; + // compute inverse(A) and stores it to LU + linalg_batch_det_helper(LU, pivot, logabsdet, dA, DType(-INFINITY), ctx); + const_cast&>(dA) = broadcast_to(reshape(dlogabsdet, \ + Shape3(logabsdet.size(0), 1, 1)), mxnet::TShape(LU.shape_)) * \ + transpose(LU, Shape3(0, 2, 1)); + } +}; + } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/la_op.cc b/src/operator/tensor/la_op.cc index 9baf8a675a89..1d9b62fcb2ca 100644 --- a/src/operator/tensor/la_op.cc +++ b/src/operator/tensor/la_op.cc @@ -1023,8 +1023,17 @@ Examples:: .set_attr("FResourceRequest", [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) .set_attr("FCompute", LaOpDetForward) +.set_attr("FGradient", ElemwiseGradUseOut{"_backward_linalg_logdet"}) .add_argument("A", "NDArray-or-Symbol", "Tensor of square matrix"); +NNVM_REGISTER_OP(_backward_linalg_logdet) +.set_num_inputs(6) +.set_num_outputs(1) +.set_attr("FResourceRequest", [](const NodeAttrs& attrs) + { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("TIsBackward", true) +.set_attr("FCompute", LaOpDetBackward); + NNVM_REGISTER_OP(_linalg_slogdet) .add_alias("linalg_slogdet") .describe(R"code(Compute the sign and log of the determinant of a matrix. @@ -1039,6 +1048,7 @@ If *n>2*, *slogdet* is performed separately on the trailing two dimensions for all inputs (batch mode). .. note:: The operator supports float32 and float64 data types only. +.. note:: The gradient is not properly defined on sign, so it's ignored. Examples:: @@ -1065,7 +1075,16 @@ Examples:: .set_attr("FInferShape", DetShape<2>) .set_attr("FInferType", DetType<2>) .set_attr("FCompute", LaOpDetForward) +.set_attr("FGradient", ElemwiseGradUseOut{"_backward_linalg_slogdet"}) .add_argument("A", "NDArray-or-Symbol", "Tensor of square matrix"); +NNVM_REGISTER_OP(_backward_linalg_slogdet) +.set_num_inputs(8) +.set_num_outputs(1) +.set_attr("FResourceRequest", [](const NodeAttrs& attrs) + { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("TIsBackward", true) +.set_attr("FCompute", LaOpDetBackward); + } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/la_op.cu b/src/operator/tensor/la_op.cu index 9217cd78a583..9b6760ee828a 100644 --- a/src/operator/tensor/la_op.cu +++ b/src/operator/tensor/la_op.cu @@ -108,9 +108,15 @@ NNVM_REGISTER_OP(_backward_linalg_det) NNVM_REGISTER_OP(_linalg_logdet) .set_attr("FCompute", LaOpDetForward); +NNVM_REGISTER_OP(_backward_linalg_logdet) +.set_attr("FCompute", LaOpDetBackward); + NNVM_REGISTER_OP(_linalg_slogdet) .set_attr("FCompute", LaOpDetForward); +NNVM_REGISTER_OP(_backward_linalg_slogdet) +.set_attr("FCompute", LaOpDetBackward); + #if MXNET_USE_CUSOLVER == 1 NNVM_REGISTER_OP(_linalg_potrf) diff --git a/src/operator/tensor/la_op.h b/src/operator/tensor/la_op.h index bc3180dfd16c..a28aa74e8651 100644 --- a/src/operator/tensor/la_op.h +++ b/src/operator/tensor/la_op.h @@ -878,12 +878,12 @@ struct LaOpDetBackwardCaller { const nnvm::NodeAttrs& attrs, const OpContext& ctx) { mshadow::Stream *s = ctx.get_stream(); - laop::op(inputs[0].FlatToKD(s), + laop::op(inputs[1].FlatToKD(s), inputs[4].FlatToKD(s), inputs[5].FlatToKD(s), inputs[6].FlatToKD(s), inputs[7].FlatToKD(s), - outputs[1].FlatToKD(s), ctx, attrs); + outputs[0].FlatToKD(s), ctx, attrs); } }; template @@ -895,7 +895,7 @@ void LaOpDetBackward(const nnvm::NodeAttrs& attrs, using namespace mshadow; Stream *s = ctx.get_stream(); CHECK_EQ(inputs.size(), (onum + 2) * 2); - CHECK_EQ(outputs.size(), onum); + CHECK_EQ(outputs.size(), 1); MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, OType, { std::vector tspace(outputs); for ( int i = 0; i < onum; ++i ) { From 41fac4c7b85cc7641dcdf2901159dcafcf673bc7 Mon Sep 17 00:00:00 2001 From: Wang Jiajun Date: Mon, 20 May 2019 11:58:38 +0800 Subject: [PATCH 10/30] stop grad for zero det --- src/operator/tensor/la_op-inl.h | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/src/operator/tensor/la_op-inl.h b/src/operator/tensor/la_op-inl.h index 8cf112c07d56..729d30975d95 100644 --- a/src/operator/tensor/la_op-inl.h +++ b/src/operator/tensor/la_op-inl.h @@ -901,6 +901,17 @@ struct inverse_backward { } }; +// Here we set grad to zero if det = 0 as a temporary method +struct StopZeroDetGrad { + template + MSHADOW_XINLINE static void Map(int i, int grad_step, DType *grad, DType *det, DType zero_det) { + int batch_ind = i / grad_step; + if (det[batch_ind] == zero_det) { + grad[i] = DType(0); + } + } +}; + // Backward of det(A) is derived from Jacobi's formula. // The closed form solution is pretty easy when A is invertible. // For non-invertible A, grad is not backwarded now. @@ -915,11 +926,16 @@ struct det_backward { const OpContext& ctx, const nnvm::NodeAttrs& attrs) { using namespace mshadow; using namespace mshadow::expr; + using namespace mxnet_op; // compute inverse(A) and stores it to LU linalg_batch_det_helper(LU, pivot, det, dA, DType(0), ctx); const_cast&>(dA) = broadcast_to(reshape(det * ddet, \ Shape3(det.size(0), 1, 1)), mxnet::TShape(LU.shape_)) * \ transpose(LU, Shape3(0, 2, 1)); + Stream *s = ctx.get_stream(); + // stop grad for zero det temporarily + Kernel::Launch(s, dA.shape_.Size(), dA.size(1) * dA.size(2), \ + dA.dptr_, det.dptr_, DType(0)); } }; @@ -937,11 +953,16 @@ struct logdet_backward { const OpContext& ctx, const nnvm::NodeAttrs& attrs) { using namespace mshadow; using namespace mshadow::expr; + using namespace mxnet_op; // compute inverse(A) and stores it to LU linalg_batch_det_helper(LU, pivot, logdet, dA, DType(-INFINITY), ctx); const_cast&>(dA) = broadcast_to(reshape(dlogdet, \ Shape3(logdet.size(0), 1, 1)), mxnet::TShape(LU.shape_)) * \ transpose(LU, Shape3(0, 2, 1)); + Stream *s = ctx.get_stream(); + // stop grad for zero det temporarily + Kernel::Launch(s, dA.shape_.Size(), dA.size(1) * dA.size(2), \ + dA.dptr_, logdet.dptr_, DType(-INFINITY)); } }; @@ -961,11 +982,16 @@ struct slogdet_backward { const OpContext& ctx, const nnvm::NodeAttrs& attrs) { using namespace mshadow; using namespace mshadow::expr; + using namespace mxnet_op; // compute inverse(A) and stores it to LU linalg_batch_det_helper(LU, pivot, logabsdet, dA, DType(-INFINITY), ctx); const_cast&>(dA) = broadcast_to(reshape(dlogabsdet, \ Shape3(logabsdet.size(0), 1, 1)), mxnet::TShape(LU.shape_)) * \ transpose(LU, Shape3(0, 2, 1)); + Stream *s = ctx.get_stream(); + // stop grad for zero det temporarily + Kernel::Launch(s, dA.shape_.Size(), dA.size(1) * dA.size(2), \ + dA.dptr_, logabsdet.dptr_, DType(-INFINITY)); } }; From 8a886d5b0f59a9e9383963fc82fc94121466b269 Mon Sep 17 00:00:00 2001 From: Wang Jiajun Date: Mon, 20 May 2019 12:07:33 +0800 Subject: [PATCH 11/30] fix --- src/operator/tensor/la_op.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/tensor/la_op.cu b/src/operator/tensor/la_op.cu index 9b6760ee828a..74fc97b75d6b 100644 --- a/src/operator/tensor/la_op.cu +++ b/src/operator/tensor/la_op.cu @@ -115,7 +115,7 @@ NNVM_REGISTER_OP(_linalg_slogdet) .set_attr("FCompute", LaOpDetForward); NNVM_REGISTER_OP(_backward_linalg_slogdet) -.set_attr("FCompute", LaOpDetBackward); +.set_attr("FCompute", LaOpDetBackward); #if MXNET_USE_CUSOLVER == 1 From 3f9b4f0357e92632ed04c96ebdd418f03809a5a3 Mon Sep 17 00:00:00 2001 From: Wang Jiajun Date: Mon, 20 May 2019 12:12:23 +0800 Subject: [PATCH 12/30] fix --- src/operator/linalg_impl.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/linalg_impl.h b/src/operator/linalg_impl.h index c0d46bad8a76..739b445b852d 100644 --- a/src/operator/linalg_impl.h +++ b/src/operator/linalg_impl.h @@ -1249,7 +1249,7 @@ void linalg_getrf(const Tensor& A, \ A.dptr_, A.stride_, pivot.dptr_)); \ CHECK_GE(ret, 0) << #fname << " failed in lapack on cpu."; \ if (check_singular) { \ - CHECK_EQ(ret, 0) << "the input matrix is in-convertible"; \ + CHECK_EQ(ret, 0) << "the input matrix is non-convertible"; \ } \ } From a7da4903d479822c6379da7f652a5a46b6ec5285 Mon Sep 17 00:00:00 2001 From: Wang Jiajun Date: Mon, 20 May 2019 16:25:26 +0800 Subject: [PATCH 13/30] reduce grad transfer --- src/operator/tensor/la_op.cc | 15 ++++++++------- src/operator/tensor/la_op.h | 34 +++++++++++++++++++++++++--------- 2 files changed, 33 insertions(+), 16 deletions(-) diff --git a/src/operator/tensor/la_op.cc b/src/operator/tensor/la_op.cc index 1d9b62fcb2ca..748db50b1cb0 100644 --- a/src/operator/tensor/la_op.cc +++ b/src/operator/tensor/la_op.cc @@ -975,11 +975,11 @@ Examples:: .set_attr("FResourceRequest", [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) .set_attr("FCompute", LaOpDetForward) -.set_attr("FGradient", ElemwiseGradUseOut{"_backward_linalg_det"}) +.set_attr("FGradient", ReduceDetGrad<1>{"_backward_linalg_det"}) .add_argument("A", "NDArray-or-Symbol", "Tensor of square matrix"); NNVM_REGISTER_OP(_backward_linalg_det) -.set_num_inputs(6) +.set_num_inputs(4) .set_num_outputs(1) .set_attr("FResourceRequest", [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) @@ -1023,11 +1023,11 @@ Examples:: .set_attr("FResourceRequest", [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) .set_attr("FCompute", LaOpDetForward) -.set_attr("FGradient", ElemwiseGradUseOut{"_backward_linalg_logdet"}) +.set_attr("FGradient", ReduceDetGrad<1>{"_backward_linalg_logdet"}) .add_argument("A", "NDArray-or-Symbol", "Tensor of square matrix"); NNVM_REGISTER_OP(_backward_linalg_logdet) -.set_num_inputs(6) +.set_num_inputs(4) .set_num_outputs(1) .set_attr("FResourceRequest", [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) @@ -1048,7 +1048,8 @@ If *n>2*, *slogdet* is performed separately on the trailing two dimensions for all inputs (batch mode). .. note:: The operator supports float32 and float64 data types only. -.. note:: The gradient is not properly defined on sign, so it's ignored. +.. note:: The gradient is not properly defined on sign, so it's not allowed + to pass gradient on it. Examples:: @@ -1075,11 +1076,11 @@ Examples:: .set_attr("FInferShape", DetShape<2>) .set_attr("FInferType", DetType<2>) .set_attr("FCompute", LaOpDetForward) -.set_attr("FGradient", ElemwiseGradUseOut{"_backward_linalg_slogdet"}) +.set_attr("FGradient", ReduceDetGrad<2>{"_backward_linalg_slogdet"}) .add_argument("A", "NDArray-or-Symbol", "Tensor of square matrix"); NNVM_REGISTER_OP(_backward_linalg_slogdet) -.set_num_inputs(8) +.set_num_inputs(5) .set_num_outputs(1) .set_attr("FResourceRequest", [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) diff --git a/src/operator/tensor/la_op.h b/src/operator/tensor/la_op.h index a28aa74e8651..50848fb2c595 100644 --- a/src/operator/tensor/la_op.h +++ b/src/operator/tensor/la_op.h @@ -865,9 +865,9 @@ struct LaOpDetBackwardCaller { const OpContext& ctx) { mshadow::Stream *s = ctx.get_stream(); laop::op(inputs[0].FlatToKD(s), - inputs[3].FlatToKD(s), - inputs[4].FlatToKD(s), - inputs[5].FlatToKD(s), + inputs[1].FlatToKD(s), + inputs[2].FlatToKD(s), + inputs[3].FlatToKD(s), outputs[0].FlatToKD(s), ctx, attrs); } }; @@ -878,11 +878,11 @@ struct LaOpDetBackwardCaller { const nnvm::NodeAttrs& attrs, const OpContext& ctx) { mshadow::Stream *s = ctx.get_stream(); - laop::op(inputs[1].FlatToKD(s), - inputs[4].FlatToKD(s), - inputs[5].FlatToKD(s), - inputs[6].FlatToKD(s), - inputs[7].FlatToKD(s), + laop::op(inputs[0].FlatToKD(s), + inputs[1].FlatToKD(s), + inputs[2].FlatToKD(s), + inputs[3].FlatToKD(s), + inputs[4].FlatToKD(s), outputs[0].FlatToKD(s), ctx, attrs); } }; @@ -894,7 +894,7 @@ void LaOpDetBackward(const nnvm::NodeAttrs& attrs, const std::vector& outputs) { using namespace mshadow; Stream *s = ctx.get_stream(); - CHECK_EQ(inputs.size(), (onum + 2) * 2); + CHECK_EQ(inputs.size(), onum + 3); CHECK_EQ(outputs.size(), 1); MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, OType, { std::vector tspace(outputs); @@ -914,6 +914,22 @@ void LaOpDetBackward(const nnvm::NodeAttrs& attrs, }); } +// Only transfer ddet and outputs to gradient +template +struct ReduceDetGrad { + const char *op_name; + std::vector operator()(const nnvm::NodePtr& n, + const std::vector& ograds) { + std::vector heads; + heads.push_back(ograds[onum - 1]); + uint32_t n_out = n->num_outputs(); + for (uint32_t i = 0; i < n_out; ++i) { + heads.emplace_back(nnvm::NodeEntry{n, i, 0}); + } + return MakeGradNode(op_name, n, heads, n->attrs.dict); + } +}; + } // namespace op } // namespace mxnet From 56a2e331a8682dfcc21e2a3e9789e85090cc135f Mon Sep 17 00:00:00 2001 From: Wang Jiajun Date: Mon, 20 May 2019 16:40:12 +0800 Subject: [PATCH 14/30] fix docs --- src/operator/tensor/la_op.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/operator/tensor/la_op.cc b/src/operator/tensor/la_op.cc index 748db50b1cb0..3da49667d5d4 100644 --- a/src/operator/tensor/la_op.cc +++ b/src/operator/tensor/la_op.cc @@ -1042,7 +1042,7 @@ Input is a tensor *A* of dimension *n >= 2*. If *n=2*, *A* is a square matrix. We compute: *sign* = *sign(det(A))* - *logdet* = *log(abs(det(A)))* + *logabsdet* = *log(abs(det(A)))* If *n>2*, *slogdet* is performed separately on the trailing two dimensions for all inputs (batch mode). @@ -1055,17 +1055,17 @@ Examples:: // Single matrix inversion A = [[2., 3.], [1., 4.]] - sign, logdet = slogdet(A) + sign, logabsdet = slogdet(A) sign = [1.] - logdet = [1.609438] + logabsdet = [1.609438] // Batch matrix inversion A = [[[2., 3.], [1., 4.]], [[1., 2.], [2., 4.]], [[1., 2.], [4., 3.]]] - sign, logdet = slogdet(A) + sign, logabsdet = slogdet(A) sign = [1., 0., -1.] - logdet = [1.609438, -inf, 1.609438] + logabsdet = [1.609438, -inf, 1.609438] )code" ADD_FILELINE) .set_num_inputs(1) .set_num_outputs(4) From a0ca56cd5d36be2ebe52b0d11d299f800904c3fb Mon Sep 17 00:00:00 2001 From: Wang Jiajun Date: Mon, 20 May 2019 22:26:46 +0800 Subject: [PATCH 15/30] update comments --- src/operator/linalg.h | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/operator/linalg.h b/src/operator/linalg.h index f3536fe19ff6..c5ce8160e871 100644 --- a/src/operator/linalg.h +++ b/src/operator/linalg.h @@ -195,11 +195,11 @@ int linalg_syevd_workspace_query(const Tensor& A, // CPU/GPU-versions of LAPACK function "getrf". Please refer to the // LAPACK documentation for further details. + // Note: // - A is input and output parameter (overwritten by LU) // - Param check_singular is only useful in cpu version. If check_singular is false, // don't throw error when A is non-invertible matrix. - template void linalg_getrf(const Tensor& A, const Tensor& pivot, @@ -216,16 +216,19 @@ void linalg_batch_getrf(const Tensor& A, // CPU/GPU-versions of LAPACK function "getri". Please refer to the // LAPACK documentation for further details. + // Note: // - pivot and LU is the output of getrf(A) // - LU is also the output parameter (overwritten by inverse(A)) - template void linalg_getri(const Tensor& LU, const Tensor& pivot, \ const Tensor& work, Stream *s = 0); +// Note that this function only implements GPU version with "getriBatched" in cuBLAS. +// Unlike lapack routines in cpu, it is computed out-of-place, so the final matrix +// inversion is stored in A. template void linalg_batch_getri(const Tensor& A, const Tensor& LU, From 1084b616ca55ab0b0ccea901addf0092000bc906 Mon Sep 17 00:00:00 2001 From: Wang Jiajun Date: Mon, 20 May 2019 22:39:52 +0800 Subject: [PATCH 16/30] fix docs --- src/operator/tensor/la_op.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/operator/tensor/la_op.cc b/src/operator/tensor/la_op.cc index 3da49667d5d4..c6680bedd36a 100644 --- a/src/operator/tensor/la_op.cc +++ b/src/operator/tensor/la_op.cc @@ -1048,8 +1048,8 @@ If *n>2*, *slogdet* is performed separately on the trailing two dimensions for all inputs (batch mode). .. note:: The operator supports float32 and float64 data types only. -.. note:: The gradient is not properly defined on sign, so it's not allowed - to pass gradient on it. +.. note:: The gradient is not properly defined on sign, so the gradient of + it is not backwarded. Examples:: From a46ba031c96da337fcbf5030b65e72280e23fab4 Mon Sep 17 00:00:00 2001 From: Wang Jiajun Date: Mon, 20 May 2019 23:17:48 +0800 Subject: [PATCH 17/30] fix lint --- src/operator/linalg_impl.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/linalg_impl.h b/src/operator/linalg_impl.h index 739b445b852d..f6c8fd42d33a 100644 --- a/src/operator/linalg_impl.h +++ b/src/operator/linalg_impl.h @@ -1263,7 +1263,7 @@ void linalg_batch_getrf(const Tensor& A, \ const Tensor& pivot, \ bool check_singular, \ Stream *s) { \ - for(index_t i = 0; i < A.size(0); ++i) { \ + for (index_t i = 0; i < A.size(0); ++i) { \ linalg_getrf(A[i], pivot[i], check_singular); \ } \ } From a701edb7d3f4ff62c4b904aab5e38426b70e261f Mon Sep 17 00:00:00 2001 From: Wang Jiajun Date: Tue, 21 May 2019 14:49:05 +0800 Subject: [PATCH 18/30] add test --- tests/python/unittest/test_operator.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index b5aa06964b29..1601108d5734 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -6515,8 +6515,13 @@ def test_laop_6(): check_numeric_gradient(sym, location, numeric_eps=num_eps, rtol=rtol_bw, atol=atol_bw, dtype=dtype) - a = np.sqrt(np.arange(4 * 4)).reshape(4, 4) + ## det(I + dot(v, v.T)) = 1 + dot(v.T, v) >= 1, so it's always invertible; + ## det is away from zero, so the value of logdet is stable + v = np.random.random(4) + a = np.eye(4) + np.outer(v, v) a = np.tile(a, (3, 1, 1)) + + # test matrix inverse r = np.eye(4) r = np.tile(r, (3, 1, 1)) test_inverse = mx.sym.linalg.inverse(data) @@ -6524,6 +6529,23 @@ def test_laop_6(): check_fw(test_eye, [a], [r]) check_grad(test_inverse, [a]) + # test matrix determinant + # det + r = np.linalg.det(a) + test_det = mx.sym.linalg.det(data) + check_fw(test_det, [a], [r]) + check_grad(test_det, [a]) + # logdet + r = np.log(np.linalg.det(a)) + test_logdet = mx.sym.linalg.logdet(data) + check_fw(test_logdet, [a], [r]) + check_grad(test_logdet, [a]) + # test slogdet + r = np.log(np.abs(np.linalg.det(a))) + _, test_slogdet = mx.sym.linalg.slogdet(data) + check_fw(test_slogdet, [a], [r]) + check_grad(test_slogdet, [a]) + @with_seed() def test_stack(): for _ in range(100): From cd3858f492b8aeb6a4c08aae97daad5f59cb1e9f Mon Sep 17 00:00:00 2001 From: Wang Jiajun Date: Tue, 21 May 2019 16:06:23 +0800 Subject: [PATCH 19/30] update docs --- docs/api/python/symbol/linalg.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/api/python/symbol/linalg.md b/docs/api/python/symbol/linalg.md index 436bab78c451..f565d7cec28e 100644 --- a/docs/api/python/symbol/linalg.md +++ b/docs/api/python/symbol/linalg.md @@ -60,6 +60,9 @@ In the rest of this document, we list routines provided by the `symbol.linalg` p extracttrian maketrian inverse + det + logdet + slogdet ``` ## API Reference From a632ac6b2f2f6fbee0cce6c80ddb02aefd137410 Mon Sep 17 00:00:00 2001 From: Wang Jiajun Date: Tue, 21 May 2019 16:17:05 +0800 Subject: [PATCH 20/30] add operator --- python/mxnet/contrib/amp/lists/symbol.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/mxnet/contrib/amp/lists/symbol.py b/python/mxnet/contrib/amp/lists/symbol.py index 2f8b4f0f9a6a..8844774ab28b 100644 --- a/python/mxnet/contrib/amp/lists/symbol.py +++ b/python/mxnet/contrib/amp/lists/symbol.py @@ -433,6 +433,9 @@ '_linalg_maketrian', '_linalg_extracttrian', '_linalg_inverse', + '_linalg_det', + '_linalg_logdet', + '_linalg_slogdet', 'linalg_syrk', 'linalg_potrf', 'linalg_potri', @@ -446,6 +449,9 @@ 'linalg_maketrian', 'linalg_extracttrian', 'linalg_inverse', + 'linalg_det', + 'linalg_logdet', + 'linalg_slogdet', '_NDArray', '_Native', '_contrib_count_sketch', From df9bcc695f9f0bad691ccaf208369a89d60df952 Mon Sep 17 00:00:00 2001 From: Wang Jiajun Date: Tue, 21 May 2019 16:32:06 +0800 Subject: [PATCH 21/30] update test --- tests/python/unittest/test_operator.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 1601108d5734..764f0e59a6d7 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -6520,6 +6520,7 @@ def test_laop_6(): v = np.random.random(4) a = np.eye(4) + np.outer(v, v) a = np.tile(a, (3, 1, 1)) + permute_mat = np.eye(4)[[1, 0, 2, 3]] # test matrix inverse r = np.eye(4) @@ -6541,10 +6542,13 @@ def test_laop_6(): check_fw(test_logdet, [a], [r]) check_grad(test_logdet, [a]) # test slogdet - r = np.log(np.abs(np.linalg.det(a))) - _, test_slogdet = mx.sym.linalg.slogdet(data) - check_fw(test_slogdet, [a], [r]) - check_grad(test_slogdet, [a]) + r1 = np.array([1., 1., 1.]) + r2 = np.log(np.abs(np.linalg.det(a))) + test_sign, test_logabsdet = mx.sym.linalg.slogdet(data) + check_fw(test_sign, [a], [r1]) + check_fw(test_sign, [np.dot(a, permute_mat)], [-r1]) + check_fw(test_logabsdet, [a], [r2]) + check_grad(test_logabsdet, [a]) @with_seed() def test_stack(): From 4855b8e8d559776bbf47cb494638d4bbde15aa6f Mon Sep 17 00:00:00 2001 From: Wang Jiajun Date: Tue, 21 May 2019 19:25:43 +0800 Subject: [PATCH 22/30] trigger CI From ece449fb273b963da9636a61403b3df594c70f63 Mon Sep 17 00:00:00 2001 From: Wang Jiajun Date: Thu, 23 May 2019 01:12:58 +0800 Subject: [PATCH 23/30] remove slash --- src/operator/tensor/la_op.cc | 72 ++++++++++++++++++------------------ 1 file changed, 36 insertions(+), 36 deletions(-) diff --git a/src/operator/tensor/la_op.cc b/src/operator/tensor/la_op.cc index c6680bedd36a..dcb0728337a6 100644 --- a/src/operator/tensor/la_op.cc +++ b/src/operator/tensor/la_op.cc @@ -73,14 +73,14 @@ Tensor Cores on suitable NVIDIA GPUs. This can sometimes give significant speedu Examples:: - // Single matrix multiply-add + Single matrix multiply-add A = [[1.0, 1.0], [1.0, 1.0]] B = [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]] C = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]] gemm(A, B, C, transpose_b=True, alpha=2.0, beta=10.0) = [[14.0, 14.0, 14.0], [14.0, 14.0, 14.0]] - // Batch matrix multiply-add + Batch matrix multiply-add A = [[[1.0, 1.0]], [[0.1, 0.1]]] B = [[[1.0, 1.0]], [[0.1, 0.1]]] C = [[[10.0]], [[0.01]]] @@ -149,13 +149,13 @@ Tensor Cores on suitable NVIDIA GPUs. This can sometimes give significant speedu Examples:: - // Single matrix multiply + Single matrix multiply A = [[1.0, 1.0], [1.0, 1.0]] B = [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]] gemm2(A, B, transpose_b=True, alpha=2.0) = [[4.0, 4.0, 4.0], [4.0, 4.0, 4.0]] - // Batch matrix multiply + Batch matrix multiply A = [[[1.0, 1.0]], [[0.1, 0.1]]] B = [[[1.0, 1.0]], [[0.1, 0.1]]] gemm2(A, B, transpose_b=True, alpha=2.0) @@ -204,11 +204,11 @@ If *n>2*, *potrf* is performed separately on the trailing two dimensions for all Examples:: - // Single matrix factorization + Single matrix factorization A = [[4.0, 1.0], [1.0, 4.25]] potrf(A) = [[2.0, 0], [0.5, 2.0]] - // Batch matrix factorization + Batch matrix factorization A = [[[4.0, 1.0], [1.0, 4.25]], [[16.0, 4.0], [4.0, 17.0]]] potrf(A) = [[[2.0, 0], [0.5, 2.0]], [[4.0, 0], [1.0, 4.0]]] )code" ADD_FILELINE) @@ -264,11 +264,11 @@ If *n>2*, *potri* is performed separately on the trailing two dimensions for all Examples:: - // Single matrix inverse + Single matrix inverse A = [[2.0, 0], [0.5, 2.0]] potri(A) = [[0.26563, -0.0625], [-0.0625, 0.25]] - // Batch matrix inverse + Batch matrix inverse A = [[[2.0, 0], [0.5, 2.0]], [[4.0, 0], [1.0, 4.0]]] potri(A) = [[[0.26563, -0.0625], [-0.0625, 0.25]], [[0.06641, -0.01562], [-0.01562, 0,0625]]] @@ -320,12 +320,12 @@ If *n>2*, *trmm* is performed separately on the trailing two dimensions for all Examples:: - // Single triangular matrix multiply + Single triangular matrix multiply A = [[1.0, 0], [1.0, 1.0]] B = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]] trmm(A, B, alpha=2.0) = [[2.0, 2.0, 2.0], [4.0, 4.0, 4.0]] - // Batch triangular matrix multiply + Batch triangular matrix multiply A = [[[1.0, 0], [1.0, 1.0]], [[1.0, 0], [1.0, 1.0]]] B = [[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], [[0.5, 0.5, 0.5], [0.5, 0.5, 0.5]]] trmm(A, B, alpha=2.0) = [[[2.0, 2.0, 2.0], [4.0, 4.0, 4.0]], @@ -382,12 +382,12 @@ If *n>2*, *trsm* is performed separately on the trailing two dimensions for all Examples:: - // Single matrix solve + Single matrix solve A = [[1.0, 0], [1.0, 1.0]] B = [[2.0, 2.0, 2.0], [4.0, 4.0, 4.0]] trsm(A, B, alpha=0.5) = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]] - // Batch matrix solve + Batch matrix solve A = [[[1.0, 0], [1.0, 1.0]], [[1.0, 0], [1.0, 1.0]]] B = [[[2.0, 2.0, 2.0], [4.0, 4.0, 4.0]], [[4.0, 4.0, 4.0], [8.0, 8.0, 8.0]]] @@ -435,11 +435,11 @@ inputs (batch mode). Examples:: - // Single matrix reduction + Single matrix reduction A = [[1.0, 1.0], [1.0, 7.0]] sumlogdiag(A) = [1.9459] - // Batch matrix reduction + Batch matrix reduction A = [[[1.0, 1.0], [1.0, 7.0]], [[3.0, 0], [0, 17.0]]] sumlogdiag(A) = [1.9459, 3.9318] )code" ADD_FILELINE) @@ -476,7 +476,7 @@ If *n>2*, then *A* represents a batch of square matrices on the trailing two dim Examples:: - // Single matrix diagonal extraction + Single matrix diagonal extraction A = [[1.0, 2.0], [3.0, 4.0]] @@ -484,7 +484,7 @@ Examples:: extractdiag(A, 1) = [2.0] - // Batch matrix diagonal extraction + Batch matrix diagonal extraction A = [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], @@ -526,7 +526,7 @@ If *n>1*, then *A* represents a batch of diagonals of square matrices. The batch Examples:: - // Single diagonal matrix construction + Single diagonal matrix construction A = [1.0, 2.0] makediag(A) = [[1.0, 0.0], @@ -536,7 +536,7 @@ Examples:: [0.0, 0.0, 2.0], [0.0, 0.0, 0.0]] - // Batch diagonal matrix construction + Batch diagonal matrix construction A = [[1.0, 2.0], [3.0, 4.0]] @@ -585,7 +585,7 @@ The *offset* and *lower* parameters determine the triangle to be extracted: Examples:: - // Single triagonal extraction + Single triagonal extraction A = [[1.0, 2.0], [3.0, 4.0]] @@ -594,7 +594,7 @@ Examples:: extracttrian(A, 1) = [2.0] extracttrian(A, -1) = [3.0] - // Batch triagonal extraction + Batch triagonal extraction A = [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], @@ -639,7 +639,7 @@ If *n>1*, then *A* represents a batch of triangular sub-matrices. The batch of c Examples:: - // Single matrix construction + Single matrix construction A = [1.0, 2.0, 3.0] maketrian(A) = [[1.0, 0.0], @@ -655,7 +655,7 @@ Examples:: [1.0, 0.0, 0.0], [2.0, 3.0, 0.0]] - // Batch matrix construction + Batch matrix construction A = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] @@ -714,7 +714,7 @@ inputs (batch mode). Examples:: - // Single matrix multiply + Single matrix multiply A = [[1., 2., 3.], [4., 5., 6.]] syrk(A, alpha=1., transpose=False) = [[14., 32.], @@ -724,7 +724,7 @@ Examples:: [22., 29., 36.], [27., 36., 45.]] - // Batch matrix multiply + Batch matrix multiply A = [[[1., 1.]], [[0.1, 0.1]]] syrk(A, alpha=2., transpose=False) = [[[4.]], [[0.04]]] )code" ADD_FILELINE) @@ -775,7 +775,7 @@ inputs (batch mode). Examples:: - // Single LQ factorization + Single LQ factorization A = [[1., 2., 3.], [4., 5., 6.]] Q, L = gelqf(A) Q = [[-0.26726124, -0.53452248, -0.80178373], @@ -783,7 +783,7 @@ Examples:: L = [[-3.74165739, 0.], [-8.55235974, 1.96396101]] - // Batch LQ factorization + Batch LQ factorization A = [[[1., 2., 3.], [4., 5., 6.]], [[7., 8., 9.], [10., 11., 12.]]] Q, L = gelqf(A) @@ -847,14 +847,14 @@ mode). In this case, *U* has *n* dimensions like *A*, and *L* has *n-1* dimensio Examples:: - // Single symmetric eigendecomposition + Single symmetric eigendecomposition A = [[1., 2.], [2., 4.]] U, L = syevd(A) U = [[0.89442719, -0.4472136], [0.4472136, 0.89442719]] L = [0., 5.] - // Batch symmetric eigendecomposition + Batch symmetric eigendecomposition A = [[[1., 2.], [2., 4.]], [[1., 2.], [2., 5.]]] U, L = syevd(A) @@ -905,11 +905,11 @@ for all inputs (batch mode). Examples:: - // Single matrix inversion + Single matrix inversion A = [[1., 4.], [2., 3.]] inverse(A) = [[-0.6, 0.8], [0.4, -0.2]] - // Batch matrix inversion + Batch matrix inversion A = [[[1., 4.], [2., 3.]], [[1., 3.], [2., 4.]]] inverse(A) = [[[-0.6, 0.8], [0.4, -0.2]], @@ -955,11 +955,11 @@ for all inputs (batch mode). Examples:: - // Single matrix inversion + Single matrix inversion A = [[1., 4.], [2., 3.]] det(A) = [-5.] - // Batch matrix inversion + Batch matrix inversion A = [[[1., 4.], [2., 3.]], [[2., 3.], [1., 4.]]] det(A) = [-5., 5.] @@ -1002,11 +1002,11 @@ for all inputs (batch mode). Examples:: - // Single matrix inversion + Single matrix inversion A = [[2., 3.], [1., 4.]] logdet(A) = [1.609438] - // Batch matrix inversion + Batch matrix inversion A = [[[2., 3.], [1., 4.]], [[1., 2.], [2., 4.]], [[1., 2.], [4., 3.]]] @@ -1053,13 +1053,13 @@ for all inputs (batch mode). Examples:: - // Single matrix inversion + Single matrix inversion A = [[2., 3.], [1., 4.]] sign, logabsdet = slogdet(A) sign = [1.] logabsdet = [1.609438] - // Batch matrix inversion + Batch matrix inversion A = [[[2., 3.], [1., 4.]], [[1., 2.], [2., 4.]], [[1., 2.], [4., 3.]]] From cd23f6aaf6ea724376a4a81e13cd81eeea426a34 Mon Sep 17 00:00:00 2001 From: Wang Jiajun Date: Fri, 24 May 2019 19:34:15 +0800 Subject: [PATCH 24/30] update operator check --- src/operator/tensor/la_op.h | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/src/operator/tensor/la_op.h b/src/operator/tensor/la_op.h index 50848fb2c595..e024693e3819 100644 --- a/src/operator/tensor/la_op.h +++ b/src/operator/tensor/la_op.h @@ -406,11 +406,12 @@ inline bool InverseShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(in_attrs->size(), 1); CHECK_EQ(out_attrs->size(), 1); const mxnet::TShape& in = (*in_attrs)[0]; + if (!ndim_is_known(in)) return false; const int ndim(in.ndim()); CHECK_GE(ndim, 2) << "Input A's dimension must be >= 2"; CHECK_EQ(in[ndim-2], in[ndim-1]) << "Input A's last two dimension must be equal"; SHAPE_ASSIGN_CHECK(*out_attrs, 0, in); - return true; + return shape_is_known(in); } // Shape inference function for det functions in linalg @@ -421,6 +422,7 @@ inline bool DetShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(in_attrs->size(), 1); CHECK_EQ(out_attrs->size(), onum + 2); const mxnet::TShape& in = (*in_attrs)[0]; + if (!ndim_is_known(in)) return false; const int ndim(in.ndim()); CHECK_GE(ndim, 2) << "Input A's dimension must be >= 2"; CHECK_EQ(in[ndim-2], in[ndim-1]) << "Input A's last two dimension must be equal"; @@ -435,7 +437,7 @@ inline bool DetShape(const nnvm::NodeAttrs& attrs, } SHAPE_ASSIGN_CHECK(*out_attrs, onum, in); /* LU */ SHAPE_ASSIGN_CHECK(*out_attrs, onum + 1, mxnet::TShape(in.begin(), in.end() - 1)); /* pivot */ - return true; + return shape_is_known(in); } // Type inference function for det functions in linalg @@ -444,16 +446,17 @@ inline bool DetType(const nnvm::NodeAttrs& attrs, std::vector* in_type, std::vector* out_type) { using namespace mshadow; - CHECK_EQ(in_type->size(), 1U); - int dtype = (*in_type)[0]; - CHECK_NE(dtype, -1) << "Input must have specified type"; - - out_type->clear(); + CHECK_EQ(in_type->size(), 1); + CHECK_EQ(out_type->size(), onum + 2); + const int dtype = (*in_type)[0]; + if (dtype == -1) return false; + CHECK(dtype == kFloat32 || dtype == kFloat64) + << "This operation only supports 32-bit and 64-bit floating point"; for (int i = 0; i < onum; ++i) { - out_type->push_back(dtype); /* sign or det or logdet */ + TYPE_ASSIGN_CHECK(*out_type, i, dtype); /* sign or det or logdet */ } - out_type->push_back(dtype); /* LU */ - out_type->push_back(mshadow::kInt32); /* pivot */ + TYPE_ASSIGN_CHECK(*out_type, onum, dtype); /* LU */ + TYPE_ASSIGN_CHECK(*out_type, onum + 1, kInt32); /* pivot */ return true; } From 02bb95ff313709a36063403a2020ca1eb5781e19 Mon Sep 17 00:00:00 2001 From: Wang Jiajun Date: Fri, 24 May 2019 18:52:45 +0800 Subject: [PATCH 25/30] update comments and docs --- src/operator/tensor/la_op-inl.h | 12 +++++++----- src/operator/tensor/la_op.cc | 4 ++-- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/operator/tensor/la_op-inl.h b/src/operator/tensor/la_op-inl.h index 729d30975d95..408e5b3f9907 100644 --- a/src/operator/tensor/la_op-inl.h +++ b/src/operator/tensor/la_op-inl.h @@ -464,9 +464,7 @@ struct inverse { } }; -// partial pivoting LU decomposition: A = PLU, so det(A) = det(P)det(L)det(U) -// det(P) depends on number of row changes in P, det(L) = 1, det(U) = prod(diag(U)) -// this kernel computes sign(det(A)), log(abs(det(A))) +// this kernel computes sign(det(A)), log(abs(det(A))) from LU decomposition struct SignedLogDet { template MSHADOW_XINLINE static void Map(int i, int N, int* pivot, @@ -487,8 +485,12 @@ struct SignedLogDet { } }; -// det = det(A), LU and pivot store the LU decomposition output which will be -// used in computing gradient +// det = det(A), the computation method is based on partial pivoting LU decomposition: +// A = PLU, so det(A) = det(P) * det(L) * det(U), +// det(P) depends on number of row changes in P +// det(L) = 1 since L has unit diagnal elemements +// det(U) = prod(diag(U)) +// LU and pivot store the LU decomposition output which will be used in computing gradient struct det { template static void op(const Tensor& A, const Tensor& det, diff --git a/src/operator/tensor/la_op.cc b/src/operator/tensor/la_op.cc index dcb0728337a6..2c322846c42e 100644 --- a/src/operator/tensor/la_op.cc +++ b/src/operator/tensor/la_op.cc @@ -905,11 +905,11 @@ for all inputs (batch mode). Examples:: - Single matrix inversion + Single matrix inverse A = [[1., 4.], [2., 3.]] inverse(A) = [[-0.6, 0.8], [0.4, -0.2]] - Batch matrix inversion + Batch matrix inverse A = [[[1., 4.], [2., 3.]], [[1., 3.], [2., 4.]]] inverse(A) = [[[-0.6, 0.8], [0.4, -0.2]], From e3e489d42ffc77ec19c3666f385bc154bf8a334b Mon Sep 17 00:00:00 2001 From: Wang Jiajun Date: Fri, 24 May 2019 19:25:27 +0800 Subject: [PATCH 26/30] update det helper function --- src/operator/linalg.h | 19 ++++++++++--------- src/operator/linalg_impl.h | 8 ++++---- src/operator/tensor/la_op-inl.h | 6 +++--- 3 files changed, 17 insertions(+), 16 deletions(-) diff --git a/src/operator/linalg.h b/src/operator/linalg.h index c5ce8160e871..e532e2b91cb9 100644 --- a/src/operator/linalg.h +++ b/src/operator/linalg.h @@ -246,16 +246,17 @@ void linalg_batch_inverse(const Tensor& A, //////////////////////////////// DET //////////////////////////////////////////// -// CPU/GPU-versions of helper functions to compute matrix determinant -// Compute matrix inversion with LU and pivot using temp workspace, -// the result stores back to LU +// CPU/GPU-versions of helper functions used in matrix determinant operators + +// Helper function in determinant backward computation: compute matrix inverse +// from LU and pivot using temp workspace, the result is stored back to LU template -void linalg_batch_det_helper(const Tensor& LU, - const Tensor& pivot, - const Tensor& det, - const Tensor& temp, - const DType zero_det, - const mxnet::OpContext& ctx); +void linalg_batch_det_backward_helper(const Tensor& LU, + const Tensor& pivot, + const Tensor& det, + const Tensor& temp, + const DType zero_det, + const mxnet::OpContext& ctx); #include "linalg_impl.h" diff --git a/src/operator/linalg_impl.h b/src/operator/linalg_impl.h index f6c8fd42d33a..8353868db6fc 100644 --- a/src/operator/linalg_impl.h +++ b/src/operator/linalg_impl.h @@ -1502,11 +1502,11 @@ LINALG_GPU_BATCH_INVERSE(gpu, double) //////////////////////////////// DET //////////////////////////////////////////// -// CPU/GPU-versions of helper functions to compute matrix determinant +// CPU/GPU-versions of helper functions used in matrix determinant operators #define LINALG_CPU_BATCH_DET_HELPER(xpu, DType) \ template<> inline \ -void linalg_batch_det_helper(const Tensor& LU, \ +void linalg_batch_det_backward_helper(const Tensor& LU, \ const Tensor& pivot, \ const Tensor& det, \ const Tensor& temp, \ @@ -1531,7 +1531,7 @@ LINALG_CPU_BATCH_DET_HELPER(cpu, double) #define LINALG_GPU_BATCH_DET_HELPER(xpu, DType) \ template<> inline \ -void linalg_batch_det_helper(const Tensor& LU, \ +void linalg_batch_det_backward_helper(const Tensor& LU, \ const Tensor& pivot, \ const Tensor& det, \ const Tensor& temp, \ @@ -1546,7 +1546,7 @@ void linalg_batch_det_helper(const Tensor& LU, \ #define LINALG_GPU_BATCH_DET_HELPER(xpu, DType) \ template<> inline \ -void linalg_batch_det_helper(const Tensor& LU, \ +void linalg_batch_det_backward_helper(const Tensor& LU, \ const Tensor& pivot, \ const Tensor& det, \ const Tensor& temp, \ diff --git a/src/operator/tensor/la_op-inl.h b/src/operator/tensor/la_op-inl.h index 408e5b3f9907..c795c2a3f375 100644 --- a/src/operator/tensor/la_op-inl.h +++ b/src/operator/tensor/la_op-inl.h @@ -930,7 +930,7 @@ struct det_backward { using namespace mshadow::expr; using namespace mxnet_op; // compute inverse(A) and stores it to LU - linalg_batch_det_helper(LU, pivot, det, dA, DType(0), ctx); + linalg_batch_det_backward_helper(LU, pivot, det, dA, DType(0), ctx); const_cast&>(dA) = broadcast_to(reshape(det * ddet, \ Shape3(det.size(0), 1, 1)), mxnet::TShape(LU.shape_)) * \ transpose(LU, Shape3(0, 2, 1)); @@ -957,7 +957,7 @@ struct logdet_backward { using namespace mshadow::expr; using namespace mxnet_op; // compute inverse(A) and stores it to LU - linalg_batch_det_helper(LU, pivot, logdet, dA, DType(-INFINITY), ctx); + linalg_batch_det_backward_helper(LU, pivot, logdet, dA, DType(-INFINITY), ctx); const_cast&>(dA) = broadcast_to(reshape(dlogdet, \ Shape3(logdet.size(0), 1, 1)), mxnet::TShape(LU.shape_)) * \ transpose(LU, Shape3(0, 2, 1)); @@ -986,7 +986,7 @@ struct slogdet_backward { using namespace mshadow::expr; using namespace mxnet_op; // compute inverse(A) and stores it to LU - linalg_batch_det_helper(LU, pivot, logabsdet, dA, DType(-INFINITY), ctx); + linalg_batch_det_backward_helper(LU, pivot, logabsdet, dA, DType(-INFINITY), ctx); const_cast&>(dA) = broadcast_to(reshape(dlogabsdet, \ Shape3(logabsdet.size(0), 1, 1)), mxnet::TShape(LU.shape_)) * \ transpose(LU, Shape3(0, 2, 1)); From 8aa6e3ba2d6d0c1b1de7515001f317601a27ae69 Mon Sep 17 00:00:00 2001 From: Wang Jiajun Date: Fri, 24 May 2019 19:39:38 +0800 Subject: [PATCH 27/30] remove logdet --- src/operator/tensor/la_op-inl.h | 46 ------------------------ src/operator/tensor/la_op.cc | 48 -------------------------- src/operator/tensor/la_op.cu | 6 ---- tests/python/unittest/test_operator.py | 5 --- 4 files changed, 105 deletions(-) diff --git a/src/operator/tensor/la_op-inl.h b/src/operator/tensor/la_op-inl.h index c795c2a3f375..de27187bca9a 100644 --- a/src/operator/tensor/la_op-inl.h +++ b/src/operator/tensor/la_op-inl.h @@ -510,25 +510,6 @@ struct det { } }; -// logdet = log(det(A)) -struct logdet { - template - static void op(const Tensor& A, const Tensor& logdet, - const Tensor& LU, const Tensor& pivot, - const OpContext& ctx, const nnvm::NodeAttrs& attrs) { - Stream *s = ctx.get_stream(); - Tensor sign = ctx.requested[0] - .get_space_typed(logdet.shape_, s); - Copy(LU, A, s); - linalg_batch_getrf(LU, pivot, false, s); - using namespace mxnet_op; - using namespace mshadow::expr; - Kernel::Launch(s, pivot.size(0), pivot.size(1), pivot.dptr_, - LU.dptr_, sign.dptr_, logdet.dptr_); - const_cast&>(logdet) = F(sign) + logdet; - } -}; - // sign = sign(det(A)) // logabsdet = log(abs(det(A))) struct slogdet { @@ -941,33 +922,6 @@ struct det_backward { } }; -// Backward of logdet(A) is derived from Jacobi's formula. -// The closed form solution is pretty easy when A is invertible. -// For non-invertible A, grad is not backwarded now. -// TODO(arcadiaphy) add implementation for non-invertible case -struct logdet_backward { - template - static void op(const Tensor& dlogdet, - const Tensor& logdet, - const Tensor& LU, - const Tensor& pivot, - const Tensor& dA, - const OpContext& ctx, const nnvm::NodeAttrs& attrs) { - using namespace mshadow; - using namespace mshadow::expr; - using namespace mxnet_op; - // compute inverse(A) and stores it to LU - linalg_batch_det_backward_helper(LU, pivot, logdet, dA, DType(-INFINITY), ctx); - const_cast&>(dA) = broadcast_to(reshape(dlogdet, \ - Shape3(logdet.size(0), 1, 1)), mxnet::TShape(LU.shape_)) * \ - transpose(LU, Shape3(0, 2, 1)); - Stream *s = ctx.get_stream(); - // stop grad for zero det temporarily - Kernel::Launch(s, dA.shape_.Size(), dA.size(1) * dA.size(2), \ - dA.dptr_, logdet.dptr_, DType(-INFINITY)); - } -}; - // Backward of slogdet(A) is derived from Jacobi's formula. // The closed form solution is pretty easy when A is invertible. // For non-invertible A, grad is not backwarded now. diff --git a/src/operator/tensor/la_op.cc b/src/operator/tensor/la_op.cc index 2c322846c42e..dade5692800d 100644 --- a/src/operator/tensor/la_op.cc +++ b/src/operator/tensor/la_op.cc @@ -986,54 +986,6 @@ NNVM_REGISTER_OP(_backward_linalg_det) .set_attr("TIsBackward", true) .set_attr("FCompute", LaOpDetBackward); -NNVM_REGISTER_OP(_linalg_logdet) -.add_alias("linalg_logdet") -.describe(R"code(Compute the log determinant of a matrix. -Input is a tensor *A* of dimension *n >= 2*. - -If *n=2*, *A* is a square matrix. We compute: - - *out* = *log(det(A))* - -If *n>2*, *logdet* is performed separately on the trailing two dimensions -for all inputs (batch mode). - -.. note:: The operator supports float32 and float64 data types only. - -Examples:: - - Single matrix inversion - A = [[2., 3.], [1., 4.]] - logdet(A) = [1.609438] - - Batch matrix inversion - A = [[[2., 3.], [1., 4.]], - [[1., 2.], [2., 4.]], - [[1., 2.], [4., 3.]]] - logdet(A) = [1.609438, -inf, nan] -)code" ADD_FILELINE) -.set_num_inputs(1) -.set_num_outputs(3) -.set_attr("FListInputNames", [](const NodeAttrs& attrs) - { return std::vector{"A"}; }) -.set_attr("FNumVisibleOutputs", [](const NodeAttrs& attrs) { - return 1; }) -.set_attr("FInferShape", DetShape<1>) -.set_attr("FInferType", DetType<1>) -.set_attr("FResourceRequest", [](const NodeAttrs& attrs) - { return std::vector{ResourceRequest::kTempSpace}; }) -.set_attr("FCompute", LaOpDetForward) -.set_attr("FGradient", ReduceDetGrad<1>{"_backward_linalg_logdet"}) -.add_argument("A", "NDArray-or-Symbol", "Tensor of square matrix"); - -NNVM_REGISTER_OP(_backward_linalg_logdet) -.set_num_inputs(4) -.set_num_outputs(1) -.set_attr("FResourceRequest", [](const NodeAttrs& attrs) - { return std::vector{ResourceRequest::kTempSpace}; }) -.set_attr("TIsBackward", true) -.set_attr("FCompute", LaOpDetBackward); - NNVM_REGISTER_OP(_linalg_slogdet) .add_alias("linalg_slogdet") .describe(R"code(Compute the sign and log of the determinant of a matrix. diff --git a/src/operator/tensor/la_op.cu b/src/operator/tensor/la_op.cu index 74fc97b75d6b..68c33180e3d5 100644 --- a/src/operator/tensor/la_op.cu +++ b/src/operator/tensor/la_op.cu @@ -105,12 +105,6 @@ NNVM_REGISTER_OP(_linalg_det) NNVM_REGISTER_OP(_backward_linalg_det) .set_attr("FCompute", LaOpDetBackward); -NNVM_REGISTER_OP(_linalg_logdet) -.set_attr("FCompute", LaOpDetForward); - -NNVM_REGISTER_OP(_backward_linalg_logdet) -.set_attr("FCompute", LaOpDetBackward); - NNVM_REGISTER_OP(_linalg_slogdet) .set_attr("FCompute", LaOpDetForward); diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 764f0e59a6d7..3c59524b4fee 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -6536,11 +6536,6 @@ def test_laop_6(): test_det = mx.sym.linalg.det(data) check_fw(test_det, [a], [r]) check_grad(test_det, [a]) - # logdet - r = np.log(np.linalg.det(a)) - test_logdet = mx.sym.linalg.logdet(data) - check_fw(test_logdet, [a], [r]) - check_grad(test_logdet, [a]) # test slogdet r1 = np.array([1., 1., 1.]) r2 = np.log(np.abs(np.linalg.det(a))) From 9826d25e0f9448026f211c1e7a4e92ae3d211eb2 Mon Sep 17 00:00:00 2001 From: Wang Jiajun Date: Fri, 24 May 2019 19:50:08 +0800 Subject: [PATCH 28/30] add no grad when det = 0 --- src/operator/tensor/la_op-inl.h | 10 ++++------ src/operator/tensor/la_op.cc | 4 ++++ 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/operator/tensor/la_op-inl.h b/src/operator/tensor/la_op-inl.h index de27187bca9a..42d1f4527575 100644 --- a/src/operator/tensor/la_op-inl.h +++ b/src/operator/tensor/la_op-inl.h @@ -884,7 +884,7 @@ struct inverse_backward { } }; -// Here we set grad to zero if det = 0 as a temporary method +// Here we set grad to zero if det = 0 struct StopZeroDetGrad { template MSHADOW_XINLINE static void Map(int i, int grad_step, DType *grad, DType *det, DType zero_det) { @@ -897,8 +897,7 @@ struct StopZeroDetGrad { // Backward of det(A) is derived from Jacobi's formula. // The closed form solution is pretty easy when A is invertible. -// For non-invertible A, grad is not backwarded now. -// TODO(arcadiaphy) add implementation for non-invertible case +// For non-invertible A, grad is not backwarded. struct det_backward { template static void op(const Tensor& ddet, @@ -924,9 +923,8 @@ struct det_backward { // Backward of slogdet(A) is derived from Jacobi's formula. // The closed form solution is pretty easy when A is invertible. -// For non-invertible A, grad is not backwarded now. +// For non-invertible A, grad is not backwarded. // Grad is not properly defined on sign, so it's not backwarded either. -// TODO(arcadiaphy) add implementation for non-invertible case struct slogdet_backward { template static void op(const Tensor& dlogabsdet, @@ -945,7 +943,7 @@ struct slogdet_backward { Shape3(logabsdet.size(0), 1, 1)), mxnet::TShape(LU.shape_)) * \ transpose(LU, Shape3(0, 2, 1)); Stream *s = ctx.get_stream(); - // stop grad for zero det temporarily + // stop grad for zero det Kernel::Launch(s, dA.shape_.Size(), dA.size(1) * dA.size(2), \ dA.dptr_, logabsdet.dptr_, DType(-INFINITY)); } diff --git a/src/operator/tensor/la_op.cc b/src/operator/tensor/la_op.cc index dade5692800d..c426e52a844f 100644 --- a/src/operator/tensor/la_op.cc +++ b/src/operator/tensor/la_op.cc @@ -952,6 +952,10 @@ If *n>2*, *det* is performed separately on the trailing two dimensions for all inputs (batch mode). .. note:: The operator supports float32 and float64 data types only. +.. note:: There is no gradient backwarded when det(A) == 0 because it's + rarely hit upon in float point computation and the Jacobi's + formula on determinant gradient is not computationally efficient + when A is non-invertible. Examples:: From d768e36078d88081ff613232d5e8860a9228f931 Mon Sep 17 00:00:00 2001 From: Wang Jiajun Date: Fri, 24 May 2019 20:11:29 +0800 Subject: [PATCH 29/30] update comments and docs --- src/operator/linalg.h | 4 ++-- src/operator/linalg_impl.h | 6 +++--- src/operator/tensor/la_op.cc | 18 ++++++++++-------- 3 files changed, 15 insertions(+), 13 deletions(-) diff --git a/src/operator/linalg.h b/src/operator/linalg.h index e532e2b91cb9..8f1eedae03ac 100644 --- a/src/operator/linalg.h +++ b/src/operator/linalg.h @@ -228,7 +228,7 @@ void linalg_getri(const Tensor& LU, // Note that this function only implements GPU version with "getriBatched" in cuBLAS. // Unlike lapack routines in cpu, it is computed out-of-place, so the final matrix -// inversion is stored in A. +// inverse is stored in A. template void linalg_batch_getri(const Tensor& A, const Tensor& LU, @@ -237,7 +237,7 @@ void linalg_batch_getri(const Tensor& A, //////////////////////////////// INVERSE //////////////////////////////////////////// -// CPU/GPU-versions of matrix inversion combining LAPACK function "getrf" and "getri" +// CPU/GPU-versions of matrix inverse combining LAPACK function "getrf" and "getri" // Note that A = inverse(B) template void linalg_batch_inverse(const Tensor& A, diff --git a/src/operator/linalg_impl.h b/src/operator/linalg_impl.h index 8353868db6fc..958e95555502 100644 --- a/src/operator/linalg_impl.h +++ b/src/operator/linalg_impl.h @@ -1429,7 +1429,7 @@ LINALG_GPU_BATCH_GETRI(DgetriBatched, double) //////////////////////////////// INVERSE //////////////////////////////////////////// -// CPU/GPU-versions of matrix inversion combining LAPACK function "getrf" and "getri" +// CPU/GPU-versions of matrix inverse combining LAPACK function "getrf" and "getri" // Note A = inverse(B) #define LINALG_CPU_BATCH_INVERSE(xpu, DType) \ @@ -1490,7 +1490,7 @@ template<> inline \ void linalg_batch_inverse(const Tensor& A, \ const Tensor& B, \ const mxnet::OpContext& ctx) { \ - LOG(FATAL) << "gpu matrix inversion requires CUDA version >= 8.0!"; \ + LOG(FATAL) << "gpu matrix inverse requires CUDA version >= 8.0!"; \ } #endif // CUDA_VERSION >= 8000 @@ -1552,7 +1552,7 @@ void linalg_batch_det_backward_helper(const Tensor& L const Tensor& temp, \ const DType zero_det, \ const mxnet::OpContext& ctx) { \ - LOG(FATAL) << "gpu matrix inversion requires CUDA version >= 8.0!"; \ + LOG(FATAL) << "gpu matrix inverse requires CUDA version >= 8.0!"; \ } #endif // CUDA_VERSION >= 8000 diff --git a/src/operator/tensor/la_op.cc b/src/operator/tensor/la_op.cc index c426e52a844f..ce7d1d5de692 100644 --- a/src/operator/tensor/la_op.cc +++ b/src/operator/tensor/la_op.cc @@ -952,18 +952,18 @@ If *n>2*, *det* is performed separately on the trailing two dimensions for all inputs (batch mode). .. note:: The operator supports float32 and float64 data types only. -.. note:: There is no gradient backwarded when det(A) == 0 because it's - rarely hit upon in float point computation and the Jacobi's - formula on determinant gradient is not computationally efficient - when A is non-invertible. +.. note:: There is no gradient backwarded when A is non-invertible (which is + equivalent to det(A) = 0) because zero is rarely hit upon in float + point computation and the Jacobi's formula on determinant gradient + is not computationally efficient when A is non-invertible. Examples:: - Single matrix inversion + Single matrix determinant A = [[1., 4.], [2., 3.]] det(A) = [-5.] - Batch matrix inversion + Batch matrix determinant A = [[[1., 4.], [2., 3.]], [[2., 3.], [1., 4.]]] det(A) = [-5., 5.] @@ -1006,16 +1006,18 @@ for all inputs (batch mode). .. note:: The operator supports float32 and float64 data types only. .. note:: The gradient is not properly defined on sign, so the gradient of it is not backwarded. +.. note:: No gradient is backwarded when A is non-invertible. Please see + the docs of operator det for detail. Examples:: - Single matrix inversion + Single matrix signed log determinant A = [[2., 3.], [1., 4.]] sign, logabsdet = slogdet(A) sign = [1.] logabsdet = [1.609438] - Batch matrix inversion + Batch matrix signed log determinant A = [[[2., 3.], [1., 4.]], [[1., 2.], [2., 4.]], [[1., 2.], [4., 3.]]] From 8870555b475f5bb5b1de2c465a837b983b435f89 Mon Sep 17 00:00:00 2001 From: Wang Jiajun Date: Fri, 24 May 2019 20:16:07 +0800 Subject: [PATCH 30/30] remove remaining logdet --- docs/api/python/symbol/linalg.md | 1 - python/mxnet/contrib/amp/lists/symbol.py | 2 -- 2 files changed, 3 deletions(-) diff --git a/docs/api/python/symbol/linalg.md b/docs/api/python/symbol/linalg.md index f565d7cec28e..5b9afbcbb782 100644 --- a/docs/api/python/symbol/linalg.md +++ b/docs/api/python/symbol/linalg.md @@ -61,7 +61,6 @@ In the rest of this document, we list routines provided by the `symbol.linalg` p maketrian inverse det - logdet slogdet ``` diff --git a/python/mxnet/contrib/amp/lists/symbol.py b/python/mxnet/contrib/amp/lists/symbol.py index 8844774ab28b..0c6a67e83d02 100644 --- a/python/mxnet/contrib/amp/lists/symbol.py +++ b/python/mxnet/contrib/amp/lists/symbol.py @@ -434,7 +434,6 @@ '_linalg_extracttrian', '_linalg_inverse', '_linalg_det', - '_linalg_logdet', '_linalg_slogdet', 'linalg_syrk', 'linalg_potrf', @@ -450,7 +449,6 @@ 'linalg_extracttrian', 'linalg_inverse', 'linalg_det', - 'linalg_logdet', 'linalg_slogdet', '_NDArray', '_Native',