diff --git a/docs/api/python/symbol/linalg.md b/docs/api/python/symbol/linalg.md index 436bab78c451..5b9afbcbb782 100644 --- a/docs/api/python/symbol/linalg.md +++ b/docs/api/python/symbol/linalg.md @@ -60,6 +60,8 @@ In the rest of this document, we list routines provided by the `symbol.linalg` p extracttrian maketrian inverse + det + slogdet ``` ## API Reference diff --git a/python/mxnet/contrib/amp/lists/symbol.py b/python/mxnet/contrib/amp/lists/symbol.py index c6cc3d1b1f00..397f4775f8cd 100644 --- a/python/mxnet/contrib/amp/lists/symbol.py +++ b/python/mxnet/contrib/amp/lists/symbol.py @@ -433,6 +433,8 @@ '_linalg_maketrian', '_linalg_extracttrian', '_linalg_inverse', + '_linalg_det', + '_linalg_slogdet', 'linalg_syrk', 'linalg_potrf', 'linalg_potri', @@ -446,6 +448,8 @@ 'linalg_maketrian', 'linalg_extracttrian', 'linalg_inverse', + 'linalg_det', + 'linalg_slogdet', '_NDArray', '_Native', '_contrib_count_sketch', diff --git a/src/operator/linalg.h b/src/operator/linalg.h index ee713e5548c0..8f1eedae03ac 100644 --- a/src/operator/linalg.h +++ b/src/operator/linalg.h @@ -195,50 +195,68 @@ int linalg_syevd_workspace_query(const Tensor& A, // CPU/GPU-versions of LAPACK function "getrf". Please refer to the // LAPACK documentation for further details. -// Note that this is A = getrf(A), so A is input and output parameter. +// Note: +// - A is input and output parameter (overwritten by LU) +// - Param check_singular is only useful in cpu version. If check_singular is false, +// don't throw error when A is non-invertible matrix. template void linalg_getrf(const Tensor& A, - const Tensor& work, + const Tensor& pivot, + bool check_singular, Stream *s = 0); template void linalg_batch_getrf(const Tensor& A, - const Tensor& work, + const Tensor& pivot, + bool check_singular, Stream *s = 0); //////////////////////////////// GETRI //////////////////////////////////////////// // CPU/GPU-versions of LAPACK function "getri". Please refer to the // LAPACK documentation for further details. -// Note that this is A = getri(A), so A is input and output parameter. +// Note: +// - pivot and LU is the output of getrf(A) +// - LU is also the output parameter (overwritten by inverse(A)) template -void linalg_getri(const Tensor& A, +void linalg_getri(const Tensor& LU, + const Tensor& pivot, \ const Tensor& work, Stream *s = 0); +// Note that this function only implements GPU version with "getriBatched" in cuBLAS. +// Unlike lapack routines in cpu, it is computed out-of-place, so the final matrix +// inverse is stored in A. template void linalg_batch_getri(const Tensor& A, - const Tensor& B, - const Tensor& work, + const Tensor& LU, + const Tensor& pivot, Stream *s = 0); -// This function determines the amount of workspace needed for linalg_getri to operate -// on a batch of matrices which is returned as number of elements of type DType. 
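Reviewer note (illustrative, not part of the patch): the updated linalg.h comments above describe the standard LAPACK contract that the rest of this change builds on — getrf factors A in place into packed LU factors plus a pivot vector, and getri rebuilds inverse(A) from those two pieces. A minimal sketch of that contract, written against SciPy rather than MXNet's internal tensors:

```python
# Minimal sketch of the getrf/getri contract, using SciPy as a stand-in
# for the MXNet-internal tensors (assumption: SciPy is illustrative only).
import numpy as np
from scipy.linalg import lu_factor, lu_solve

A = np.array([[1., 4.],
              [2., 3.]])

# "getrf": A is overwritten by its packed LU factors; the pivots come back too.
lu, piv = lu_factor(A)

# "getri": rebuild inverse(A) from (LU, pivot); solving against the identity
# plays the role of the getri call declared in the header above.
A_inv = lu_solve((lu, piv), np.eye(2))
assert np.allclose(A @ A_inv, np.eye(2))
```

On CPU, linalg_batch_inverse now runs this same two-step pipeline once per matrix in the batch.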
-template -int linalg_getri_workspace_query(const Tensor& A, - Stream *s = 0); - //////////////////////////////// INVERSE //////////////////////////////////////////// -// CPU/GPU-versions of matrix inversion combining LAPACK function "getrf" and "getri" +// CPU/GPU-versions of matrix inverse combining LAPACK function "getrf" and "getri" // Note that A = inverse(B) template void linalg_batch_inverse(const Tensor& A, const Tensor& B, - const Tensor& work, - Stream *s = 0); + const mxnet::OpContext& ctx); + +//////////////////////////////// DET //////////////////////////////////////////// + +// CPU/GPU-versions of helper functions used in matrix determinant operators + +// Helper function in determinant backward computation: compute matrix inverse +// from LU and pivot using temp workspace, the result is stored back to LU +template +void linalg_batch_det_backward_helper(const Tensor& LU, + const Tensor& pivot, + const Tensor& det, + const Tensor& temp, + const DType zero_det, + const mxnet::OpContext& ctx); #include "linalg_impl.h" diff --git a/src/operator/linalg_impl.h b/src/operator/linalg_impl.h index 718e3f9c5aa0..958e95555502 100644 --- a/src/operator/linalg_impl.h +++ b/src/operator/linalg_impl.h @@ -1243,23 +1243,40 @@ LINALG_GPU_SYEVD_WORKSPACE_QUERY(DnDsyevd, double) #define LINALG_CPU_GETRF(fname, DType) \ template<> inline \ void linalg_getrf(const Tensor& A, \ - const Tensor& work, \ - Stream *s) { \ - int *ipiv = reinterpret_cast(work.dptr_); \ + const Tensor& pivot, \ + bool check_singular, Stream *s) { \ int ret(MXNET_LAPACK_##fname(MXNET_LAPACK_COL_MAJOR, A.size(1), A.size(0), \ - A.dptr_, A.stride_, ipiv)); \ - CHECK_EQ(ret, 0) << #fname << " failed in lapack on cpu."; \ + A.dptr_, A.stride_, pivot.dptr_)); \ + CHECK_GE(ret, 0) << #fname << " failed in lapack on cpu."; \ + if (check_singular) { \ + CHECK_EQ(ret, 0) << "the input matrix is non-convertible"; \ + } \ } LINALG_CPU_GETRF(sgetrf, float) LINALG_CPU_GETRF(dgetrf, double) + +#define LINALG_CPU_BATCH_GETRF(fname, DType) \ +template<> inline \ +void linalg_batch_getrf(const Tensor& A, \ + const Tensor& pivot, \ + bool check_singular, \ + Stream *s) { \ + for (index_t i = 0; i < A.size(0); ++i) { \ + linalg_getrf(A[i], pivot[i], check_singular); \ + } \ +} + +LINALG_CPU_BATCH_GETRF(sgetrf, float) +LINALG_CPU_BATCH_GETRF(dgetrf, double) + #ifdef __CUDACC__ // "getrfBatched" and "getriBatched" in cuBLAS must have DType *matrices[] as input // to store the pointers of each batch matrix. This kernel is used to build the // pointer array. 
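The check_singular flag introduced above maps directly onto LAPACK's info return code: info > 0 means U has an exact zero on its diagonal, which must abort an inverse but is harmless for a determinant (the result is simply 0). A hedged sketch of that behaviour using SciPy's raw getrf wrapper (the helper name and error strings are illustrative assumptions, not MXNet code):

```python
# Hedged sketch of check_singular: LAPACK getrf reports a singular U through
# a positive info code; only raise when the caller actually needs invertibility.
import numpy as np
from scipy.linalg.lapack import dgetrf

def batch_getrf(batch, check_singular):
    """Factor each matrix in the batch, as the CPU macro's loop does."""
    out = []
    for a in batch:
        lu, piv, info = dgetrf(a)
        if info < 0:
            raise ValueError("illegal argument passed to getrf")
        if check_singular and info > 0:
            raise ValueError("the input matrix is non-invertible")
        out.append((lu, piv))
    return out

good = [np.array([[2., 3.], [1., 4.]])]
bad = [np.zeros((2, 2))]
batch_getrf(good, check_singular=True)   # fine: matrix is invertible
batch_getrf(bad, check_singular=False)   # accepted: det() will simply be 0
```

This is why the det/slogdet forward passes later in the patch call linalg_batch_getrf with check_singular set to false, while the inverse operator keeps it true.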
-struct set_matrix : public mxnet::op::mxnet_op::tunable { +struct set_matrix { template MSHADOW_XINLINE static void Map(int i, DType **p, DType *m, int step) { p[i] = m + i * step; @@ -1277,23 +1294,22 @@ struct set_matrix : public mxnet::op::mxnet_op::tunable { #define LINALG_GPU_BATCH_GETRF(fname, DType) \ template<> inline \ void linalg_batch_getrf(const Tensor& A, \ - const Tensor& work, \ + const Tensor& pivot, \ + bool check_singular, \ Stream *s) { \ using namespace mxnet; \ using namespace mxnet::op::mxnet_op; \ CHECK_NOTNULL(s); \ + Storage::Handle info = Storage::Get()->Alloc(sizeof(int) * A.size(0), Context::GPU()); \ Storage::Handle A_ptr_buf = Storage::Get()->Alloc(sizeof(DType *) * A.size(0), Context::GPU()); \ DType **A_ptr = static_cast(A_ptr_buf.dptr); \ - const Tensor temp(work.dptr_, A.shape_, s); \ - int *pivot = reinterpret_cast(temp.dptr_ + temp.shape_.Size()); \ - int *info = pivot + A.size(0) * A.size(1); \ - Copy(temp, A, s); \ - Kernel::Launch(s, temp.size(0), \ - A_ptr, temp.dptr_, \ - temp.size(1) * temp.size(2)); \ + Kernel::Launch(s, A.size(0), \ + A_ptr, A.dptr_, \ + A.size(1) * A.size(2)); \ CUBLAS_CALL(cublas##fname(Stream::GetBlasHandle(s), \ - A.size(1), A_ptr, A.size(2), pivot, \ - info, A.size(0))) \ + A.size(1), A_ptr, A.size(2), pivot.dptr_, \ + static_cast(info.dptr), A.size(0))) \ + Storage::Get()->Free(info); \ Storage::Get()->Free(A_ptr_buf); \ } @@ -1302,7 +1318,8 @@ void linalg_batch_getrf(const Tensor& A, \ #define LINALG_GPU_BATCH_GETRF(fname, DType) \ template<> inline \ void linalg_batch_getrf(const Tensor& A, \ - const Tensor& work, \ + const Tensor& pivot, \ + bool check_singular, \ Stream *s) { \ LOG(FATAL) << "batched getrf requires CUDA version >= 8.0!"; \ } @@ -1319,38 +1336,37 @@ LINALG_GPU_BATCH_GETRF(DgetrfBatched, double) // CPU/GPU-versions of LAPACK function "getri" // The input of this function should be col-major for performance. -// Tensor work holds space for ipiv, work in getri #define LINALG_CPU_GETRI(fname, DType) \ template<> inline \ -void linalg_getri(const Tensor& A, \ +void linalg_getri(const Tensor& LU, \ + const Tensor& pivot, \ const Tensor& work, \ Stream *s) { \ - DType wkopt; \ - MXNET_LAPACK_##fname(MXNET_LAPACK_COL_MAJOR, A.size(0), A.dptr_, \ - A.stride_, nullptr, &wkopt, -1); \ - int lwork(static_cast(wkopt)); \ - int *ipiv = reinterpret_cast(work.dptr_); \ - DType *pwork = reinterpret_cast(ipiv + A.size(0)); \ - int ret(MXNET_LAPACK_##fname(MXNET_LAPACK_COL_MAJOR, A.size(0), A.dptr_, \ - A.stride_, ipiv, pwork, lwork)); \ + int ret(MXNET_LAPACK_##fname(MXNET_LAPACK_COL_MAJOR, LU.size(0), LU.dptr_, \ + LU.stride_, pivot.dptr_, work.dptr_, work.size(0))); \ CHECK_EQ(ret, 0) << #fname << " failed in lapack on cpu."; \ } LINALG_CPU_GETRI(sgetri, float) LINALG_CPU_GETRI(dgetri, double) -// Query workspace for the whole batch of matrices.For cpu version, the workspace -// is re-used, so space for only one matrix is enough. 
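For context on what the batched getrf above is ultimately used for: the determinant operators added later in this patch recover det(A) from the LU factors and the pivot vector as det(P) * prod(diag(U)), where det(P) is the parity of the row swaps. A small NumPy/SciPy sketch of that reconstruction (SciPy's 0-based pivot convention is an assumption taken from its lu_factor documentation):

```python
# Reconstruct det(A) from a pivoted LU factorization, mirroring what the
# batched getrf is used for further down in this patch.
import numpy as np
from scipy.linalg import lu_factor

A = np.array([[1., 4.],
              [2., 3.]])
lu, piv = lu_factor(A)                        # piv is 0-based in SciPy
swaps = np.sum(piv != np.arange(A.shape[0]))  # rows actually interchanged
det = (-1.0) ** swaps * np.prod(np.diag(lu))  # det(P) * prod(diag(U))
assert np.allclose(det, np.linalg.det(A))     # both give -5.0
```

The SignedLogDet kernel added later does the same computation in sign/log space to avoid overflow.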
+template +int linalg_getri_workspace_query(const Tensor& A, \ + Stream *s) { + LOG(FATAL) << "it only takes float or double Tensor"; + return 0; +} + +// Query workspace for "getri" #define LINALG_CPU_GETRI_WORKSPACE_QUERY(func, DType) \ template<> inline \ -int linalg_getri_workspace_query(const Tensor& A, \ +int linalg_getri_workspace_query(const Tensor& A, \ Stream *s) { \ - const Tensor& matrix = A[0]; \ DType lwork(0); \ - MXNET_LAPACK_##func(MXNET_LAPACK_COL_MAJOR, matrix.size(0), matrix.dptr_, \ - matrix.stride_, nullptr, &lwork, -1); \ - int ipiv = (sizeof(int) * matrix.size(0) + sizeof(DType) - 1) / sizeof(DType); \ - return ipiv + static_cast(lwork); \ + MXNET_LAPACK_##func(MXNET_LAPACK_COL_MAJOR, A.size(0), A.dptr_, \ + A.stride_, nullptr, &lwork, -1); \ + return lwork; \ } + LINALG_CPU_GETRI_WORKSPACE_QUERY(sgetri, float) LINALG_CPU_GETRI_WORKSPACE_QUERY(dgetri, double) @@ -1367,41 +1383,30 @@ LINALG_CPU_GETRI_WORKSPACE_QUERY(dgetri, double) #define LINALG_GPU_BATCH_GETRI(fname, DType) \ template<> inline \ void linalg_batch_getri(const Tensor& A, \ - const Tensor& B, \ - const Tensor& work, \ + const Tensor& LU, \ + const Tensor& pivot, \ Stream *s) { \ using namespace mxnet; \ using namespace mxnet::op::mxnet_op; \ CHECK_NOTNULL(s); \ + Storage::Handle info = Storage::Get()->Alloc(sizeof(int) * A.size(0), Context::GPU()); \ Storage::Handle A_ptr_buf = Storage::Get()->Alloc(sizeof(DType *) * A.size(0), Context::GPU()); \ DType **A_ptr = static_cast(A_ptr_buf.dptr); \ - Storage::Handle B_ptr_buf = Storage::Get()->Alloc(sizeof(DType *) * A.size(0), Context::GPU()); \ - DType **B_ptr = static_cast(B_ptr_buf.dptr); \ - Tensor temp(work.dptr_, A.shape_, s); \ - int *pivot = reinterpret_cast(temp.dptr_ + temp.shape_.Size()); \ - int *info = pivot + A.size(0) * A.size(1); \ + Storage::Handle LU_ptr_buf = Storage::Get()->Alloc(sizeof(DType *) * A.size(0), Context::GPU()); \ + DType **LU_ptr = static_cast(LU_ptr_buf.dptr); \ Kernel::Launch(s, A.size(0), \ A_ptr, A.dptr_, \ A.size(1) * A.size(2)); \ - Kernel::Launch(s, temp.size(0), \ - B_ptr, temp.dptr_, \ - temp.size(1) * temp.size(2)); \ - CUBLAS_CALL(cublas##fname(Stream::GetBlasHandle(s), \ - A.size(1), const_cast(B_ptr), \ - B.size(2), const_cast(pivot), \ - A_ptr, A.size(2), info, A.size(0))) \ + Kernel::Launch(s, LU.size(0), \ + LU_ptr, LU.dptr_, \ + LU.size(1) * LU.size(2)); \ + CUBLAS_CALL(cublas##fname(Stream::GetBlasHandle(s), A.size(1), \ + const_cast(LU_ptr), LU.size(2), \ + const_cast(pivot.dptr_), A_ptr, A.size(2), \ + static_cast(info.dptr), A.size(0))) \ + Storage::Get()->Free(info); \ Storage::Get()->Free(A_ptr_buf); \ - Storage::Get()->Free(B_ptr_buf); \ -} - -#define LINALG_GPU_GETRI_WORKSPACE_QUERY(fname, DType) \ -template<> inline \ -int linalg_getri_workspace_query(const Tensor& A, \ - Stream *s) { \ - int pivot_size = sizeof(int) * A.size(0) * A.size(1); \ - int info_size = sizeof(int) * A.size(0); \ - int matrix_size = sizeof(DType) * A.shape_.Size(); \ - return (pivot_size + info_size + matrix_size + sizeof(DType) - 1) / sizeof(DType); \ + Storage::Get()->Free(LU_ptr_buf); \ } #else @@ -1409,49 +1414,49 @@ int linalg_getri_workspace_query(const Tensor& A, \ #define LINALG_GPU_BATCH_GETRI(fname, DType) \ template<> inline \ void linalg_batch_getri(const Tensor& A, \ - const Tensor& B, \ - const Tensor& work, \ + const Tensor& LU, \ + const Tensor& pivot, \ Stream *s) { \ LOG(FATAL) << "batched getri requires CUDA version >= 8.0!"; \ } -#define LINALG_GPU_GETRI_WORKSPACE_QUERY(fname, DType) \ -template<> 
inline \ -int linalg_getri_workspace_query(const Tensor& A, \ - Stream *s) { \ - LOG(FATAL) << "batched getri requires CUDA version >= 8.0!"; \ -} - #endif // CUDA_VERSION >= 8000 LINALG_GPU_BATCH_GETRI(SgetriBatched, float) LINALG_GPU_BATCH_GETRI(DgetriBatched, double) -LINALG_GPU_GETRI_WORKSPACE_QUERY(SgetriBatched, float) -LINALG_GPU_GETRI_WORKSPACE_QUERY(DgetriBatched, double) - #endif // __CUDACC__ //////////////////////////////// INVERSE //////////////////////////////////////////// -// CPU/GPU-versions of matrix inversion combining LAPACK function "getrf" and "getri" +// CPU/GPU-versions of matrix inverse combining LAPACK function "getrf" and "getri" // Note A = inverse(B) #define LINALG_CPU_BATCH_INVERSE(xpu, DType) \ template<> inline \ void linalg_batch_inverse(const Tensor& A, \ const Tensor& B, \ - const Tensor& work, \ - Stream *s) { \ + const mxnet::OpContext& ctx) { \ + Stream *s = ctx.get_stream(); \ + int lwork(linalg_getri_workspace_query(A[0], s)); \ + int workspace_size = (sizeof(int) * A.size(1) + sizeof(DType) * lwork + \ + sizeof(DType) - 1) / sizeof(DType); \ + Tensor workspace = ctx.requested[0].\ + get_space_typed(Shape1(workspace_size), s); \ + const Tensor pivot(reinterpret_cast(workspace.dptr_), \ + Shape1(A.size(1))); \ + const Tensor work(reinterpret_cast(pivot.dptr_ + pivot.MSize()), \ + Shape1(lwork)); \ if (A.dptr_ != B.dptr_) Copy(A, B, s); \ for (index_t i = 0; i < A.size(0); ++i) { \ - linalg_getrf(A[i], work, s); \ - linalg_getri(A[i], work, s); \ + linalg_getrf(A[i], pivot, true, s); \ + linalg_getri(A[i], pivot, work, s); \ } \ } LINALG_CPU_BATCH_INVERSE(cpu, float) LINALG_CPU_BATCH_INVERSE(cpu, double) + #ifdef __CUDACC__ // GETRF and GETRI only available with cuda8 or higher. @@ -1461,10 +1466,21 @@ LINALG_CPU_BATCH_INVERSE(cpu, double) template<> inline \ void linalg_batch_inverse(const Tensor& A, \ const Tensor& B, \ - const Tensor& work, \ - Stream *s) { \ - linalg_batch_getrf(B, work, s); \ - linalg_batch_getri(A, B, work, s); \ + const mxnet::OpContext& ctx) { \ + Stream *s = ctx.get_stream(); \ + int pivot_size = sizeof(int) * A.size(0) * A.size(1); \ + int matrix_size = sizeof(DType) * A.shape_.Size(); \ + int workspace_size = (pivot_size + matrix_size + \ + sizeof(DType) - 1) / sizeof(DType); \ + Tensor workspace = ctx.requested[0].\ + get_space_typed(Shape1(workspace_size), s); \ + const Tensor pivot(reinterpret_cast(workspace.dptr_), \ + Shape2(A.size(0), A.size(1))); \ + const Tensor LU(reinterpret_cast(pivot.dptr_ + pivot.MSize()), \ + A.shape_); \ + Copy(LU, B, s); \ + linalg_batch_getrf(LU, pivot, true, s); \ + linalg_batch_getri(A, LU, pivot, s); \ } #else @@ -1473,9 +1489,8 @@ void linalg_batch_inverse(const Tensor& A, \ template<> inline \ void linalg_batch_inverse(const Tensor& A, \ const Tensor& B, \ - const Tensor& work, \ - Stream *s) { \ - LOG(FATAL) << "batched getrf and getri requires CUDA version >= 8.0!"; \ + const mxnet::OpContext& ctx) { \ + LOG(FATAL) << "gpu matrix inverse requires CUDA version >= 8.0!"; \ } #endif // CUDA_VERSION >= 8000 @@ -1485,4 +1500,64 @@ LINALG_GPU_BATCH_INVERSE(gpu, double) #endif // __CUDACC__ +//////////////////////////////// DET //////////////////////////////////////////// + +// CPU/GPU-versions of helper functions used in matrix determinant operators + +#define LINALG_CPU_BATCH_DET_HELPER(xpu, DType) \ +template<> inline \ +void linalg_batch_det_backward_helper(const Tensor& LU, \ + const Tensor& pivot, \ + const Tensor& det, \ + const Tensor& temp, \ + const DType zero_det, \ + const 
mxnet::OpContext& ctx) { \ + Stream *s = ctx.get_stream(); \ + int lwork(linalg_getri_workspace_query(LU[0], s)); \ + Tensor work = ctx.requested[0].\ + get_space_typed(Shape1(lwork), s); \ + for (index_t i = 0; i < LU.size(0); ++i) { \ + if (det[i] != zero_det) { \ + linalg_getri(LU[i], pivot[i], work, s); \ + } \ + } \ +} + +LINALG_CPU_BATCH_DET_HELPER(cpu, float) +LINALG_CPU_BATCH_DET_HELPER(cpu, double) + +// GETRF and GETRI only available with cuda8 or higher. +#if CUDA_VERSION >= 8000 + +#define LINALG_GPU_BATCH_DET_HELPER(xpu, DType) \ +template<> inline \ +void linalg_batch_det_backward_helper(const Tensor& LU, \ + const Tensor& pivot, \ + const Tensor& det, \ + const Tensor& temp, \ + const DType zero_det, \ + const mxnet::OpContext& ctx) { \ + Stream *s = ctx.get_stream(); \ + linalg_batch_getri(temp, LU, pivot, s); \ + Copy(LU, temp, s); \ +} + +#else + +#define LINALG_GPU_BATCH_DET_HELPER(xpu, DType) \ +template<> inline \ +void linalg_batch_det_backward_helper(const Tensor& LU, \ + const Tensor& pivot, \ + const Tensor& det, \ + const Tensor& temp, \ + const DType zero_det, \ + const mxnet::OpContext& ctx) { \ + LOG(FATAL) << "gpu matrix inverse requires CUDA version >= 8.0!"; \ +} + +#endif // CUDA_VERSION >= 8000 + +LINALG_GPU_BATCH_DET_HELPER(gpu, float) +LINALG_GPU_BATCH_DET_HELPER(gpu, double) + #endif // MXNET_OPERATOR_LINALG_IMPL_H_ diff --git a/src/operator/tensor/la_op-inl.h b/src/operator/tensor/la_op-inl.h index 4dead87b3dce..42d1f4527575 100644 --- a/src/operator/tensor/la_op-inl.h +++ b/src/operator/tensor/la_op-inl.h @@ -458,14 +458,73 @@ struct inverse { template static void op(const Tensor& B, const Tensor& A, const OpContext& ctx, const nnvm::NodeAttrs& attrs) { - Stream *s = ctx.get_stream(); - // Reserve workspace (size determined by query) - int lwork(linalg_getri_workspace_query(A, s)); - Tensor work = ctx.requested[0] - .get_space_typed(Shape1(lwork), s); // Since inverse(A) = trans(inverse(trans(A))), so we don't need to transpose // A even if we are using the col-major version of getrf and getri routines. - linalg_batch_inverse(A, B, work, s); + linalg_batch_inverse(A, B, ctx); + } +}; + +// this kernel computes sign(det(A)), log(abs(det(A))) from LU decomposition +struct SignedLogDet { + template + MSHADOW_XINLINE static void Map(int i, int N, int* pivot, + DType *LU, DType* sign, DType *logdet) { + int changes(0); + DType diag_sign(1); + DType diag_logsum(0); + int *pivot_mat = pivot + i * N; + DType *LU_mat = LU + i * N * N; + for (int j = 0; j < N; ++j) { + changes += (pivot_mat[j] != (j + 1)); + DType diag = LU_mat[j * (N + 1)]; + diag_sign *= ((DType(0) < diag) - (diag < DType(0))); + diag_logsum += std::log(std::abs(diag)); + } + sign[i] = (changes % 2 == 1 ? 
DType(-1) : DType(1)) * diag_sign; + logdet[i] = diag_logsum; + } +}; + +// det = det(A), the computation method is based on partial pivoting LU decomposition: +// A = PLU, so det(A) = det(P) * det(L) * det(U), +// det(P) depends on number of row changes in P +// det(L) = 1 since L has unit diagnal elemements +// det(U) = prod(diag(U)) +// LU and pivot store the LU decomposition output which will be used in computing gradient +struct det { + template + static void op(const Tensor& A, const Tensor& det, + const Tensor& LU, const Tensor& pivot, + const OpContext& ctx, const nnvm::NodeAttrs& attrs) { + Stream *s = ctx.get_stream(); + Tensor sign = ctx.requested[0] + .get_space_typed(det.shape_, s); + Copy(LU, A, s); + // since det(A) = det(trans(A)), so we'll use col-major blas routines here + linalg_batch_getrf(LU, pivot, false, s); + using namespace mxnet_op; + using namespace mshadow::expr; + Kernel::Launch(s, pivot.size(0), pivot.size(1), pivot.dptr_, + LU.dptr_, sign.dptr_, det.dptr_); + const_cast&>(det) = sign * F(det); + } +}; + +// sign = sign(det(A)) +// logabsdet = log(abs(det(A))) +struct slogdet { + template + static void op(const Tensor& A, const Tensor& sign, + const Tensor& logabsdet, const Tensor& LU, + const Tensor& pivot, const OpContext& ctx, + const nnvm::NodeAttrs& attrs) { + Stream *s = ctx.get_stream(); + Copy(LU, A, s); + linalg_batch_getrf(LU, pivot, false, s); + using namespace mxnet_op; + using namespace mshadow::expr; + Kernel::Launch(s, pivot.size(0), pivot.size(1), pivot.dptr_, + LU.dptr_, sign.dptr_, logabsdet.dptr_); } }; @@ -825,6 +884,71 @@ struct inverse_backward { } }; +// Here we set grad to zero if det = 0 +struct StopZeroDetGrad { + template + MSHADOW_XINLINE static void Map(int i, int grad_step, DType *grad, DType *det, DType zero_det) { + int batch_ind = i / grad_step; + if (det[batch_ind] == zero_det) { + grad[i] = DType(0); + } + } +}; + +// Backward of det(A) is derived from Jacobi's formula. +// The closed form solution is pretty easy when A is invertible. +// For non-invertible A, grad is not backwarded. +struct det_backward { + template + static void op(const Tensor& ddet, + const Tensor& det, + const Tensor& LU, + const Tensor& pivot, + const Tensor& dA, + const OpContext& ctx, const nnvm::NodeAttrs& attrs) { + using namespace mshadow; + using namespace mshadow::expr; + using namespace mxnet_op; + // compute inverse(A) and stores it to LU + linalg_batch_det_backward_helper(LU, pivot, det, dA, DType(0), ctx); + const_cast&>(dA) = broadcast_to(reshape(det * ddet, \ + Shape3(det.size(0), 1, 1)), mxnet::TShape(LU.shape_)) * \ + transpose(LU, Shape3(0, 2, 1)); + Stream *s = ctx.get_stream(); + // stop grad for zero det temporarily + Kernel::Launch(s, dA.shape_.Size(), dA.size(1) * dA.size(2), \ + dA.dptr_, det.dptr_, DType(0)); + } +}; + +// Backward of slogdet(A) is derived from Jacobi's formula. +// The closed form solution is pretty easy when A is invertible. +// For non-invertible A, grad is not backwarded. +// Grad is not properly defined on sign, so it's not backwarded either. 
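Both backward structs rest on Jacobi's formula: d det(A)/dA = det(A) * inv(A)^T, and hence d log|det(A)|/dA = inv(A)^T, which is why the helper overwrites LU with inverse(A) and the gradients multiply by its transpose. A quick finite-difference sanity check of those closed forms (illustrative only, not part of the patch):

```python
# Finite-difference check of Jacobi's formula:
# d det(A)/dA = det(A) * inv(A).T  and  d log|det(A)|/dA = inv(A).T
import numpy as np

rng = np.random.RandomState(0)
v = rng.rand(3)
A = np.eye(3) + np.outer(v, v)           # comfortably non-singular
eps = 1e-6

num_grad = np.zeros_like(A)
for i in range(3):
    for j in range(3):
        E = np.zeros_like(A)
        E[i, j] = eps
        num_grad[i, j] = (np.linalg.det(A + E) - np.linalg.det(A - E)) / (2 * eps)

analytic = np.linalg.det(A) * np.linalg.inv(A).T
assert np.allclose(num_grad, analytic, rtol=1e-4)
# Dividing by det(A) gives the gradient of log|det(A)|, i.e. inv(A).T,
# which is the closed form used by the slogdet backward below.
```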
+struct slogdet_backward { + template + static void op(const Tensor& dlogabsdet, + const Tensor& sign, + const Tensor& logabsdet, + const Tensor& LU, + const Tensor& pivot, + const Tensor& dA, + const OpContext& ctx, const nnvm::NodeAttrs& attrs) { + using namespace mshadow; + using namespace mshadow::expr; + using namespace mxnet_op; + // compute inverse(A) and stores it to LU + linalg_batch_det_backward_helper(LU, pivot, logabsdet, dA, DType(-INFINITY), ctx); + const_cast&>(dA) = broadcast_to(reshape(dlogabsdet, \ + Shape3(logabsdet.size(0), 1, 1)), mxnet::TShape(LU.shape_)) * \ + transpose(LU, Shape3(0, 2, 1)); + Stream *s = ctx.get_stream(); + // stop grad for zero det + Kernel::Launch(s, dA.shape_.Size(), dA.size(1) * dA.size(2), \ + dA.dptr_, logabsdet.dptr_, DType(-INFINITY)); + } +}; + } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/la_op.cc b/src/operator/tensor/la_op.cc index 2fa1fd3a1cb2..ce7d1d5de692 100644 --- a/src/operator/tensor/la_op.cc +++ b/src/operator/tensor/la_op.cc @@ -73,14 +73,14 @@ Tensor Cores on suitable NVIDIA GPUs. This can sometimes give significant speedu Examples:: - // Single matrix multiply-add + Single matrix multiply-add A = [[1.0, 1.0], [1.0, 1.0]] B = [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]] C = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]] gemm(A, B, C, transpose_b=True, alpha=2.0, beta=10.0) = [[14.0, 14.0, 14.0], [14.0, 14.0, 14.0]] - // Batch matrix multiply-add + Batch matrix multiply-add A = [[[1.0, 1.0]], [[0.1, 0.1]]] B = [[[1.0, 1.0]], [[0.1, 0.1]]] C = [[[10.0]], [[0.01]]] @@ -149,13 +149,13 @@ Tensor Cores on suitable NVIDIA GPUs. This can sometimes give significant speedu Examples:: - // Single matrix multiply + Single matrix multiply A = [[1.0, 1.0], [1.0, 1.0]] B = [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]] gemm2(A, B, transpose_b=True, alpha=2.0) = [[4.0, 4.0, 4.0], [4.0, 4.0, 4.0]] - // Batch matrix multiply + Batch matrix multiply A = [[[1.0, 1.0]], [[0.1, 0.1]]] B = [[[1.0, 1.0]], [[0.1, 0.1]]] gemm2(A, B, transpose_b=True, alpha=2.0) @@ -204,11 +204,11 @@ If *n>2*, *potrf* is performed separately on the trailing two dimensions for all Examples:: - // Single matrix factorization + Single matrix factorization A = [[4.0, 1.0], [1.0, 4.25]] potrf(A) = [[2.0, 0], [0.5, 2.0]] - // Batch matrix factorization + Batch matrix factorization A = [[[4.0, 1.0], [1.0, 4.25]], [[16.0, 4.0], [4.0, 17.0]]] potrf(A) = [[[2.0, 0], [0.5, 2.0]], [[4.0, 0], [1.0, 4.0]]] )code" ADD_FILELINE) @@ -264,11 +264,11 @@ If *n>2*, *potri* is performed separately on the trailing two dimensions for all Examples:: - // Single matrix inverse + Single matrix inverse A = [[2.0, 0], [0.5, 2.0]] potri(A) = [[0.26563, -0.0625], [-0.0625, 0.25]] - // Batch matrix inverse + Batch matrix inverse A = [[[2.0, 0], [0.5, 2.0]], [[4.0, 0], [1.0, 4.0]]] potri(A) = [[[0.26563, -0.0625], [-0.0625, 0.25]], [[0.06641, -0.01562], [-0.01562, 0,0625]]] @@ -320,12 +320,12 @@ If *n>2*, *trmm* is performed separately on the trailing two dimensions for all Examples:: - // Single triangular matrix multiply + Single triangular matrix multiply A = [[1.0, 0], [1.0, 1.0]] B = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]] trmm(A, B, alpha=2.0) = [[2.0, 2.0, 2.0], [4.0, 4.0, 4.0]] - // Batch triangular matrix multiply + Batch triangular matrix multiply A = [[[1.0, 0], [1.0, 1.0]], [[1.0, 0], [1.0, 1.0]]] B = [[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], [[0.5, 0.5, 0.5], [0.5, 0.5, 0.5]]] trmm(A, B, alpha=2.0) = [[[2.0, 2.0, 2.0], [4.0, 4.0, 4.0]], @@ -382,12 +382,12 @@ If *n>2*, *trsm* is performed separately on 
the trailing two dimensions for all Examples:: - // Single matrix solve + Single matrix solve A = [[1.0, 0], [1.0, 1.0]] B = [[2.0, 2.0, 2.0], [4.0, 4.0, 4.0]] trsm(A, B, alpha=0.5) = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]] - // Batch matrix solve + Batch matrix solve A = [[[1.0, 0], [1.0, 1.0]], [[1.0, 0], [1.0, 1.0]]] B = [[[2.0, 2.0, 2.0], [4.0, 4.0, 4.0]], [[4.0, 4.0, 4.0], [8.0, 8.0, 8.0]]] @@ -435,11 +435,11 @@ inputs (batch mode). Examples:: - // Single matrix reduction + Single matrix reduction A = [[1.0, 1.0], [1.0, 7.0]] sumlogdiag(A) = [1.9459] - // Batch matrix reduction + Batch matrix reduction A = [[[1.0, 1.0], [1.0, 7.0]], [[3.0, 0], [0, 17.0]]] sumlogdiag(A) = [1.9459, 3.9318] )code" ADD_FILELINE) @@ -476,7 +476,7 @@ If *n>2*, then *A* represents a batch of square matrices on the trailing two dim Examples:: - // Single matrix diagonal extraction + Single matrix diagonal extraction A = [[1.0, 2.0], [3.0, 4.0]] @@ -484,7 +484,7 @@ Examples:: extractdiag(A, 1) = [2.0] - // Batch matrix diagonal extraction + Batch matrix diagonal extraction A = [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], @@ -526,7 +526,7 @@ If *n>1*, then *A* represents a batch of diagonals of square matrices. The batch Examples:: - // Single diagonal matrix construction + Single diagonal matrix construction A = [1.0, 2.0] makediag(A) = [[1.0, 0.0], @@ -536,7 +536,7 @@ Examples:: [0.0, 0.0, 2.0], [0.0, 0.0, 0.0]] - // Batch diagonal matrix construction + Batch diagonal matrix construction A = [[1.0, 2.0], [3.0, 4.0]] @@ -585,7 +585,7 @@ The *offset* and *lower* parameters determine the triangle to be extracted: Examples:: - // Single triagonal extraction + Single triagonal extraction A = [[1.0, 2.0], [3.0, 4.0]] @@ -594,7 +594,7 @@ Examples:: extracttrian(A, 1) = [2.0] extracttrian(A, -1) = [3.0] - // Batch triagonal extraction + Batch triagonal extraction A = [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], @@ -639,7 +639,7 @@ If *n>1*, then *A* represents a batch of triangular sub-matrices. The batch of c Examples:: - // Single matrix construction + Single matrix construction A = [1.0, 2.0, 3.0] maketrian(A) = [[1.0, 0.0], @@ -655,7 +655,7 @@ Examples:: [1.0, 0.0, 0.0], [2.0, 3.0, 0.0]] - // Batch matrix construction + Batch matrix construction A = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] @@ -714,7 +714,7 @@ inputs (batch mode). Examples:: - // Single matrix multiply + Single matrix multiply A = [[1., 2., 3.], [4., 5., 6.]] syrk(A, alpha=1., transpose=False) = [[14., 32.], @@ -724,7 +724,7 @@ Examples:: [22., 29., 36.], [27., 36., 45.]] - // Batch matrix multiply + Batch matrix multiply A = [[[1., 1.]], [[0.1, 0.1]]] syrk(A, alpha=2., transpose=False) = [[[4.]], [[0.04]]] )code" ADD_FILELINE) @@ -775,7 +775,7 @@ inputs (batch mode). Examples:: - // Single LQ factorization + Single LQ factorization A = [[1., 2., 3.], [4., 5., 6.]] Q, L = gelqf(A) Q = [[-0.26726124, -0.53452248, -0.80178373], @@ -783,7 +783,7 @@ Examples:: L = [[-3.74165739, 0.], [-8.55235974, 1.96396101]] - // Batch LQ factorization + Batch LQ factorization A = [[[1., 2., 3.], [4., 5., 6.]], [[7., 8., 9.], [10., 11., 12.]]] Q, L = gelqf(A) @@ -847,14 +847,14 @@ mode). In this case, *U* has *n* dimensions like *A*, and *L* has *n-1* dimensio Examples:: - // Single symmetric eigendecomposition + Single symmetric eigendecomposition A = [[1., 2.], [2., 4.]] U, L = syevd(A) U = [[0.89442719, -0.4472136], [0.4472136, 0.89442719]] L = [0., 5.] 
- // Batch symmetric eigendecomposition + Batch symmetric eigendecomposition A = [[[1., 2.], [2., 4.]], [[1., 2.], [2., 5.]]] U, L = syevd(A) @@ -905,11 +905,11 @@ for all inputs (batch mode). Examples:: - // Single matrix inversion + Single matrix inverse A = [[1., 4.], [2., 3.]] inverse(A) = [[-0.6, 0.8], [0.4, -0.2]] - // Batch matrix inversion + Batch matrix inverse A = [[[1., 4.], [2., 3.]], [[1., 3.], [2., 4.]]] inverse(A) = [[[-0.6, 0.8], [0.4, -0.2]], @@ -939,5 +939,111 @@ NNVM_REGISTER_OP(_backward_linalg_inverse) .set_attr("TIsBackward", true) .set_attr("FCompute", LaOpBackward); +NNVM_REGISTER_OP(_linalg_det) +.add_alias("linalg_det") +.describe(R"code(Compute the determinant of a matrix. +Input is a tensor *A* of dimension *n >= 2*. + +If *n=2*, *A* is a square matrix. We compute: + + *out* = *det(A)* + +If *n>2*, *det* is performed separately on the trailing two dimensions +for all inputs (batch mode). + +.. note:: The operator supports float32 and float64 data types only. +.. note:: There is no gradient backwarded when A is non-invertible (which is + equivalent to det(A) = 0) because zero is rarely hit upon in float + point computation and the Jacobi's formula on determinant gradient + is not computationally efficient when A is non-invertible. + +Examples:: + + Single matrix determinant + A = [[1., 4.], [2., 3.]] + det(A) = [-5.] + + Batch matrix determinant + A = [[[1., 4.], [2., 3.]], + [[2., 3.], [1., 4.]]] + det(A) = [-5., 5.] +)code" ADD_FILELINE) +.set_num_inputs(1) +.set_num_outputs(3) +.set_attr("FListInputNames", [](const NodeAttrs& attrs) + { return std::vector{"A"}; }) +.set_attr("FNumVisibleOutputs", [](const NodeAttrs& attrs) { + return 1; }) +.set_attr("FInferShape", DetShape<1>) +.set_attr("FInferType", DetType<1>) +.set_attr("FResourceRequest", [](const NodeAttrs& attrs) + { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("FCompute", LaOpDetForward) +.set_attr("FGradient", ReduceDetGrad<1>{"_backward_linalg_det"}) +.add_argument("A", "NDArray-or-Symbol", "Tensor of square matrix"); + +NNVM_REGISTER_OP(_backward_linalg_det) +.set_num_inputs(4) +.set_num_outputs(1) +.set_attr("FResourceRequest", [](const NodeAttrs& attrs) + { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("TIsBackward", true) +.set_attr("FCompute", LaOpDetBackward); + +NNVM_REGISTER_OP(_linalg_slogdet) +.add_alias("linalg_slogdet") +.describe(R"code(Compute the sign and log of the determinant of a matrix. +Input is a tensor *A* of dimension *n >= 2*. + +If *n=2*, *A* is a square matrix. We compute: + + *sign* = *sign(det(A))* + *logabsdet* = *log(abs(det(A)))* + +If *n>2*, *slogdet* is performed separately on the trailing two dimensions +for all inputs (batch mode). + +.. note:: The operator supports float32 and float64 data types only. +.. note:: The gradient is not properly defined on sign, so the gradient of + it is not backwarded. +.. note:: No gradient is backwarded when A is non-invertible. Please see + the docs of operator det for detail. + +Examples:: + + Single matrix signed log determinant + A = [[2., 3.], [1., 4.]] + sign, logabsdet = slogdet(A) + sign = [1.] + logabsdet = [1.609438] + + Batch matrix signed log determinant + A = [[[2., 3.], [1., 4.]], + [[1., 2.], [2., 4.]], + [[1., 2.], [4., 3.]]] + sign, logabsdet = slogdet(A) + sign = [1., 0., -1.] 
+ logabsdet = [1.609438, -inf, 1.609438] +)code" ADD_FILELINE) +.set_num_inputs(1) +.set_num_outputs(4) +.set_attr("FListInputNames", [](const NodeAttrs& attrs) + { return std::vector{"A"}; }) +.set_attr("FNumVisibleOutputs", [](const NodeAttrs& attrs) { + return 2; }) +.set_attr("FInferShape", DetShape<2>) +.set_attr("FInferType", DetType<2>) +.set_attr("FCompute", LaOpDetForward) +.set_attr("FGradient", ReduceDetGrad<2>{"_backward_linalg_slogdet"}) +.add_argument("A", "NDArray-or-Symbol", "Tensor of square matrix"); + +NNVM_REGISTER_OP(_backward_linalg_slogdet) +.set_num_inputs(5) +.set_num_outputs(1) +.set_attr("FResourceRequest", [](const NodeAttrs& attrs) + { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("TIsBackward", true) +.set_attr("FCompute", LaOpDetBackward); + } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/la_op.cu b/src/operator/tensor/la_op.cu index 3ef714e00c18..68c33180e3d5 100644 --- a/src/operator/tensor/la_op.cu +++ b/src/operator/tensor/la_op.cu @@ -99,6 +99,18 @@ NNVM_REGISTER_OP(_linalg_inverse) NNVM_REGISTER_OP(_backward_linalg_inverse) .set_attr("FCompute", LaOpBackward); +NNVM_REGISTER_OP(_linalg_det) +.set_attr("FCompute", LaOpDetForward); + +NNVM_REGISTER_OP(_backward_linalg_det) +.set_attr("FCompute", LaOpDetBackward); + +NNVM_REGISTER_OP(_linalg_slogdet) +.set_attr("FCompute", LaOpDetForward); + +NNVM_REGISTER_OP(_backward_linalg_slogdet) +.set_attr("FCompute", LaOpDetBackward); + #if MXNET_USE_CUSOLVER == 1 NNVM_REGISTER_OP(_linalg_potrf) diff --git a/src/operator/tensor/la_op.h b/src/operator/tensor/la_op.h index 5b0c7e3562dc..e024693e3819 100644 --- a/src/operator/tensor/la_op.h +++ b/src/operator/tensor/la_op.h @@ -406,10 +406,57 @@ inline bool InverseShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(in_attrs->size(), 1); CHECK_EQ(out_attrs->size(), 1); const mxnet::TShape& in = (*in_attrs)[0]; + if (!ndim_is_known(in)) return false; const int ndim(in.ndim()); CHECK_GE(ndim, 2) << "Input A's dimension must be >= 2"; CHECK_EQ(in[ndim-2], in[ndim-1]) << "Input A's last two dimension must be equal"; SHAPE_ASSIGN_CHECK(*out_attrs, 0, in); + return shape_is_known(in); +} + +// Shape inference function for det functions in linalg +template +inline bool DetShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector* in_attrs, + mxnet::ShapeVector* out_attrs) { + CHECK_EQ(in_attrs->size(), 1); + CHECK_EQ(out_attrs->size(), onum + 2); + const mxnet::TShape& in = (*in_attrs)[0]; + if (!ndim_is_known(in)) return false; + const int ndim(in.ndim()); + CHECK_GE(ndim, 2) << "Input A's dimension must be >= 2"; + CHECK_EQ(in[ndim-2], in[ndim-1]) << "Input A's last two dimension must be equal"; + mxnet::TShape out; + if (ndim == 2) { + out = mxnet::TShape(1, 1); + } else { + out = mxnet::TShape(in.begin(), in.end() - 2); + } + for (int i = 0; i < onum; ++i) { + SHAPE_ASSIGN_CHECK(*out_attrs, i, out); /* sign or det or logdet */ + } + SHAPE_ASSIGN_CHECK(*out_attrs, onum, in); /* LU */ + SHAPE_ASSIGN_CHECK(*out_attrs, onum + 1, mxnet::TShape(in.begin(), in.end() - 1)); /* pivot */ + return shape_is_known(in); +} + +// Type inference function for det functions in linalg +template +inline bool DetType(const nnvm::NodeAttrs& attrs, + std::vector* in_type, + std::vector* out_type) { + using namespace mshadow; + CHECK_EQ(in_type->size(), 1); + CHECK_EQ(out_type->size(), onum + 2); + const int dtype = (*in_type)[0]; + if (dtype == -1) return false; + CHECK(dtype == kFloat32 || dtype == kFloat64) + << "This operation only supports 32-bit and 64-bit 
floating point"; + for (int i = 0; i < onum; ++i) { + TYPE_ASSIGN_CHECK(*out_type, i, dtype); /* sign or det or logdet */ + } + TYPE_ASSIGN_CHECK(*out_type, onum, dtype); /* LU */ + TYPE_ASSIGN_CHECK(*out_type, onum + 1, kInt32); /* pivot */ return true; } @@ -753,6 +800,139 @@ void LaOpBackwSyevd(const nnvm::NodeAttrs& attrs, }); } + +template +struct LaOpDetForwardCaller { + static void op(const std::vector& inputs, + const std::vector& outputs, + const nnvm::NodeAttrs& attrs, + const OpContext& ctx) { + CHECK(false) << "no specialized LaOpDetForward defined for template parameters"; + } +}; +template +struct LaOpDetForwardCaller { + static void op(const std::vector& inputs, + const std::vector& outputs, + const nnvm::NodeAttrs& attrs, + const OpContext& ctx) { + mshadow::Stream *s = ctx.get_stream(); + laop::op(inputs[0].FlatToKD(s), + outputs[0].FlatToKD(s), + outputs[1].FlatToKD(s), + outputs[2].FlatToKD(s), ctx, attrs); + } +}; +template +struct LaOpDetForwardCaller { + static void op(const std::vector& inputs, + const std::vector& outputs, + const nnvm::NodeAttrs& attrs, + const OpContext& ctx) { + mshadow::Stream *s = ctx.get_stream(); + laop::op(inputs[0].FlatToKD(s), + outputs[0].FlatToKD(s), + outputs[1].FlatToKD(s), + outputs[2].FlatToKD(s), + outputs[3].FlatToKD(s), ctx, attrs); + } +}; +template +void LaOpDetForward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + CHECK_EQ(inputs.size(), 1); + CHECK_EQ(outputs.size(), onum + 2); + MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, OType, { + LaOpDetForwardCaller::op(inputs, outputs, attrs, ctx); + }); +} + +template +struct LaOpDetBackwardCaller { + static void op(const std::vector& inputs, + const std::vector& outputs, + const nnvm::NodeAttrs& attrs, + const OpContext& ctx) { + CHECK(false) << "no specialized LaOpDetBackward defined for template parameters"; + } +}; +template +struct LaOpDetBackwardCaller { + static void op(const std::vector& inputs, + const std::vector& outputs, + const nnvm::NodeAttrs& attrs, + const OpContext& ctx) { + mshadow::Stream *s = ctx.get_stream(); + laop::op(inputs[0].FlatToKD(s), + inputs[1].FlatToKD(s), + inputs[2].FlatToKD(s), + inputs[3].FlatToKD(s), + outputs[0].FlatToKD(s), ctx, attrs); + } +}; +template +struct LaOpDetBackwardCaller { + static void op(const std::vector& inputs, + const std::vector& outputs, + const nnvm::NodeAttrs& attrs, + const OpContext& ctx) { + mshadow::Stream *s = ctx.get_stream(); + laop::op(inputs[0].FlatToKD(s), + inputs[1].FlatToKD(s), + inputs[2].FlatToKD(s), + inputs[3].FlatToKD(s), + inputs[4].FlatToKD(s), + outputs[0].FlatToKD(s), ctx, attrs); + } +}; +template +void LaOpDetBackward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + Stream *s = ctx.get_stream(); + CHECK_EQ(inputs.size(), onum + 3); + CHECK_EQ(outputs.size(), 1); + MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, OType, { + std::vector tspace(outputs); + for ( int i = 0; i < onum; ++i ) { + if ( req[i] == kAddTo ) { + tspace[i].dptr_ = ctx.requested[0] + .get_space_typed(Shape1(outputs[i].Size()), s).dptr_; + } + } + LaOpDetBackwardCaller::op(inputs, tspace, attrs, ctx); + for ( int i = 0; i < onum; ++i ) { + if ( req[i] == kAddTo ) { + Tensor out = outputs[i].FlatTo1D(s); + out += tspace[i].FlatTo1D(s); + } + } + }); +} + +// Only transfer ddet and 
outputs to gradient +template +struct ReduceDetGrad { + const char *op_name; + std::vector operator()(const nnvm::NodePtr& n, + const std::vector& ograds) { + std::vector heads; + heads.push_back(ograds[onum - 1]); + uint32_t n_out = n->num_outputs(); + for (uint32_t i = 0; i < n_out; ++i) { + heads.emplace_back(nnvm::NodeEntry{n, i, 0}); + } + return MakeGradNode(op_name, n, heads, n->attrs.dict); + } +}; + } // namespace op } // namespace mxnet diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 1a33c8e78cb3..58d1e42b4c84 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -6856,8 +6856,14 @@ def test_laop_6(): check_numeric_gradient(sym, location, numeric_eps=num_eps, rtol=rtol_bw, atol=atol_bw, dtype=dtype) - a = np.sqrt(np.arange(4 * 4)).reshape(4, 4) + ## det(I + dot(v, v.T)) = 1 + dot(v.T, v) >= 1, so it's always invertible; + ## det is away from zero, so the value of logdet is stable + v = np.random.random(4) + a = np.eye(4) + np.outer(v, v) a = np.tile(a, (3, 1, 1)) + permute_mat = np.eye(4)[[1, 0, 2, 3]] + + # test matrix inverse r = np.eye(4) r = np.tile(r, (3, 1, 1)) test_inverse = mx.sym.linalg.inverse(data) @@ -6865,6 +6871,21 @@ def test_laop_6(): check_fw(test_eye, [a], [r]) check_grad(test_inverse, [a]) + # test matrix determinant + # det + r = np.linalg.det(a) + test_det = mx.sym.linalg.det(data) + check_fw(test_det, [a], [r]) + check_grad(test_det, [a]) + # test slogdet + r1 = np.array([1., 1., 1.]) + r2 = np.log(np.abs(np.linalg.det(a))) + test_sign, test_logabsdet = mx.sym.linalg.slogdet(data) + check_fw(test_sign, [a], [r1]) + check_fw(test_sign, [np.dot(a, permute_mat)], [-r1]) + check_fw(test_logabsdet, [a], [r2]) + check_grad(test_logabsdet, [a]) + @with_seed() def test_stack(): for _ in range(100):
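To round off the test changes: the new fixture A = I + outer(v, v) is always invertible because, by the matrix determinant lemma, det(I + v v^T) = 1 + v^T v >= 1, and the single column swap in permute_mat flips only the sign. The sketch below restates that reasoning in NumPy and shows imperative-style usage of the new operators; the mx.nd.linalg.det / mx.nd.linalg.slogdet calls are assumed to mirror the symbol API registered in la_op.cc.

```python
# Illustrative usage of the new operators against NumPy references.
import numpy as np
import mxnet as mx

v = np.random.random(4)
a = np.tile(np.eye(4) + np.outer(v, v), (3, 1, 1))
assert np.allclose(np.linalg.det(a), 1.0 + v.dot(v))   # determinant lemma

nd_a = mx.nd.array(a)
det = mx.nd.linalg.det(nd_a)
sign, logabsdet = mx.nd.linalg.slogdet(nd_a)
assert np.allclose(det.asnumpy(), np.linalg.det(a), rtol=1e-4)
assert np.allclose(logabsdet.asnumpy(), np.log(np.linalg.det(a)), rtol=1e-4)

# A single column swap is an odd permutation, so it flips only the sign.
permute_mat = np.eye(4)[[1, 0, 2, 3]]
sign_p, _ = mx.nd.linalg.slogdet(mx.nd.array(a.dot(permute_mat)))
assert np.allclose(sign_p.asnumpy(), -np.ones(3))
```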