From a92f91d5f305d3889c67b1d4e850cc3df13ddd63 Mon Sep 17 00:00:00 2001 From: Asmus Hetzel Date: Fri, 22 Mar 2019 00:19:15 +0100 Subject: [PATCH] added extraction/generation of diagonal and triangonal matrices to linalg --- docs/api/python/ndarray/linalg.md | 6 +- docs/api/python/symbol/linalg.md | 6 +- src/operator/tensor/la_op-inl.h | 94 ++++++++++ src/operator/tensor/la_op.cc | 231 +++++++++++++++++++++++++ src/operator/tensor/la_op.cu | 24 +++ src/operator/tensor/la_op.h | 68 ++++++++ tests/python/unittest/test_operator.py | 45 +++++ 7 files changed, 472 insertions(+), 2 deletions(-) diff --git a/docs/api/python/ndarray/linalg.md b/docs/api/python/ndarray/linalg.md index 41436c3ba2d1..b73d9680a874 100644 --- a/docs/api/python/ndarray/linalg.md +++ b/docs/api/python/ndarray/linalg.md @@ -51,10 +51,14 @@ In the rest of this document, we list routines provided by the `ndarray.linalg` potri trmm trsm - sumlogdiag syrk gelqf syevd + sumlogdiag + extractdiag + makediag + extracttrian + maketrian ``` ## API Reference diff --git a/docs/api/python/symbol/linalg.md b/docs/api/python/symbol/linalg.md index f1891e29f896..5b467b501247 100644 --- a/docs/api/python/symbol/linalg.md +++ b/docs/api/python/symbol/linalg.md @@ -51,10 +51,14 @@ In the rest of this document, we list routines provided by the `symbol.linalg` p potri trmm trsm - sumlogdiag syrk gelqf syevd + sumlogdiag + extractdiag + makediag + extracttrian + maketrian ``` ## API Reference diff --git a/src/operator/tensor/la_op-inl.h b/src/operator/tensor/la_op-inl.h index e89a0824a948..bda8137675a8 100644 --- a/src/operator/tensor/la_op-inl.h +++ b/src/operator/tensor/la_op-inl.h @@ -229,6 +229,100 @@ struct sumlogdiag { } }; +template +struct CopyDiag { + template + MSHADOW_XINLINE static void Map(int i, int k, int n, DType* A, DType* B) { + // Index of the matrix from which the diagonal should be extracted. + const int matrix(i / (n-abs(k))); + // Index of the diagonal element that should be extracted. + const int index(i % (n-abs(k))); + // row/col that must be looked up. + const int row(index-(k < 0 ? k : 0)), col(index+(k > 0 ? k :0)); + if (forward) { + B[i] = A[(matrix*n+row)*n+col]; + } else { + B[(matrix*n+row)*n+col] = A[i]; + } + } +}; + +struct copydiag { + // Extracts diagonal from matrix. + template + static void op(const Tensor& A, const Tensor& B, + const OpContext& ctx, const nnvm::NodeAttrs& attrs) { + using namespace mxnet_op; + Stream *s = ctx.get_stream(); + const LaDiagParam& param = nnvm::get(attrs.parsed); + Kernel, xpu>::Launch(s, B.MSize(), param.offset, A.size(1), A.dptr_, B.dptr_); + } + // Sets diagonal in matrix. + template + static void op(const Tensor& A, const Tensor& B, + const OpContext& ctx, const nnvm::NodeAttrs& attrs) { + using namespace mxnet_op; + Stream *s = ctx.get_stream(); + const LaDiagParam& param = nnvm::get(attrs.parsed); + Kernel::Launch(s, B.MSize(), B.dptr_); + Kernel, xpu>::Launch(s, A.MSize(), param.offset, B.size(1), A.dptr_, B.dptr_); + } +}; + +template +struct CopyTrian { + template + MSHADOW_XINLINE static void Map(int i, bool lower, int k, int n, DType* A, DType* B) { + // Matrix that this index belongs to. + const int matrix(i/(n*n)); + // Row/Col that this index represents. + int row((i/n)%n), col(i%n); + if ((k > 0) || ((k == 0) && !lower)) { + // When working on upper triangle we switch to transposed coordinates for indexing. + int tmp(row); + row = col; + col = tmp; + } + // Actual row inside the lower triangular matrix after offset adjustment. + row -= abs(k); + if (row >= col) { + // Index in the 1-dimensional array that holds the values of the triangle. + const int index((row*(row+1))/2+col); + // Total number of entries in the triangle. + const int m(((n-abs(k))*(n-abs(k)+1))/2); + if (forward) { + B[m*matrix+index] = A[i]; + } else { + B[i] = A[m*matrix+index]; + } + } + } +}; + +struct copytrian { + // Extracts triangle from matrix. + template + static void op(const Tensor& A, const Tensor& B, + const OpContext& ctx, const nnvm::NodeAttrs& attrs) { + using namespace mxnet_op; + Stream *s = ctx.get_stream(); + const LaTrianParam& param = nnvm::get(attrs.parsed); + Kernel, xpu>::Launch(s, A.MSize(), param.lower, param.offset, + A.size(1), A.dptr_, B.dptr_); + } + // Sets triangle in matrix. + template + static void op(const Tensor& A, const Tensor& B, + const OpContext& ctx, const nnvm::NodeAttrs& attrs) { + using namespace mxnet_op; + Stream *s = ctx.get_stream(); + const LaTrianParam& param = nnvm::get(attrs.parsed); + Kernel::Launch(s, B.MSize(), B.dptr_); + Kernel, xpu>::Launch(s, B.MSize(), param.lower, param.offset, + B.size(1), A.dptr_, B.dptr_); + } +}; + // B = syrk(A) struct syrk { template diff --git a/src/operator/tensor/la_op.cc b/src/operator/tensor/la_op.cc index 12cea91f5800..d6e64c4f78cd 100644 --- a/src/operator/tensor/la_op.cc +++ b/src/operator/tensor/la_op.cc @@ -33,6 +33,8 @@ DMLC_REGISTER_PARAMETER(LaMatrixMacParam); DMLC_REGISTER_PARAMETER(LaMatrixMultParam); DMLC_REGISTER_PARAMETER(LaCholeskyParam); DMLC_REGISTER_PARAMETER(LaTriangMatrixMultParam); +DMLC_REGISTER_PARAMETER(LaDiagParam); +DMLC_REGISTER_PARAMETER(LaTrianParam); DMLC_REGISTER_PARAMETER(LaSyrkParam); NNVM_REGISTER_OP(_linalg_gemm) @@ -461,6 +463,235 @@ NNVM_REGISTER_OP(_backward_linalg_sumlogdiag) .set_attr("TIsBackward", true) .set_attr("FCompute", LaOpBackward); +NNVM_REGISTER_OP(_linalg_extractdiag) +.add_alias("linalg_extractdiag") +.describe(R"code(Extracts the diagonal entries of a square matrix. +Input is a tensor *A* of dimension *n >= 2*. + +If *n=2*, then *A* represents a single square matrix which diagonal elements get extracted as a 1-dimensional tensor. + +If *n>2*, then *A* represents a batch of square matrices on the trailing two dimensions. The extracted diagonals are returned as an *n-1*-dimensional tensor. + +.. note:: The operator supports float32 and float64 data types only. + +Examples:: + + // Single matrix diagonal extraction + A = [[1.0, 2.0], + [3.0, 4.0]] + + extractdiag(A) = [1.0, 4.0] + + extractdiag(A, 1) = [2.0] + + // Batch matrix diagonal extraction + A = [[[1.0, 2.0], + [3.0, 4.0]], + [[5.0, 6.0], + [7.0, 8.0]]] + + extractdiag(A) = [[1.0, 4.0], + [5.0, 8.0]] +)code" ADD_FILELINE) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FListInputNames", [](const NodeAttrs& attrs) + { return std::vector{"A"}; } ) +.set_attr("FInferShape", LaDiagTrianShape) +.set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FCompute", LaOpForward) +.set_attr("FGradient", ElemwiseGradUseNone{"_backward_linalg_extractdiag"}) +.add_argument("A", "NDArray-or-Symbol", "Tensor of square matrices") +.add_arguments(LaDiagParam::__FIELDS__()); + +NNVM_REGISTER_OP(_backward_linalg_extractdiag) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FResourceRequest", [](const NodeAttrs& attrs) + { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("TIsBackward", true) +.set_attr("FCompute", LaOpBackward); + +NNVM_REGISTER_OP(_linalg_makediag) +.add_alias("linalg_makediag") +.describe(R"code(Constructs a square matrix with the input as diagonal. +Input is a tensor *A* of dimension *n >= 1*. + +If *n=1*, then *A* represents the diagonal entries of a single square matrix. This matrix will be returned as a 2-dimensional tensor. +If *n>1*, then *A* represents a batch of diagonals of square matrices. The batch of diagonal matrices will be returned as an *n+1*-dimensional tensor. + +.. note:: The operator supports float32 and float64 data types only. + +Examples:: + + // Single diagonal matrix construction + A = [1.0, 2.0] + + makediag(A) = [[1.0, 0.0], + [0.0, 2.0]] + + makediag(A, 1) = [[0.0, 1.0, 0.0], + [0.0, 0.0, 2.0], + [0.0, 0.0, 0.0]] + + // Batch diagonal matrix construction + A = [[1.0, 2.0], + [3.0, 4.0]] + + makediag(A) = [[[1.0, 0.0], + [0.0, 2.0]], + [[3.0, 0.0], + [0.0, 4.0]]] +)code" ADD_FILELINE) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FListInputNames", [](const NodeAttrs& attrs) + { return std::vector{"A"}; } ) +.set_attr("FInferShape", LaDiagTrianShape) +.set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FCompute", LaOpForward) +.set_attr("FGradient", ElemwiseGradUseNone{"_backward_linalg_makediag"}) +.add_argument("A", "NDArray-or-Symbol", "Tensor of diagonal entries") +.add_arguments(LaDiagParam::__FIELDS__()); + +NNVM_REGISTER_OP(_backward_linalg_makediag) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FResourceRequest", [](const NodeAttrs& attrs) + { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("TIsBackward", true) +.set_attr("FCompute", LaOpBackward); + +NNVM_REGISTER_OP(_linalg_extracttrian) +.add_alias("linalg_extracttrian") +.describe(R"code(Extracts a triangular sub-matrix from a square matrix. +Input is a tensor *A* of dimension *n >= 2*. + +If *n=2*, then *A* represents a single square matrix from which a triangular sub-matrix is extracted as a 1-dimensional tensor. + +If *n>2*, then *A* represents a batch of square matrices on the trailing two dimensions. The extracted triangular sub-matrices are returned as an *n-1*-dimensional tensor. + +The *offset* and *lower* parameters determine the triangle to be extracted: + +- When *offset = 0* either the lower or upper triangle with respect to the main diagonal is extracted depending on the value of parameter *lower*. +- When *offset = k > 0* the upper triangle with respect to the k-th diagonal above the main diagonal is extracted. +- When *offset = k < 0* the lower triangle with respect to the k-th diagonal below the main diagonal is extracted. + +.. note:: The operator supports float32 and float64 data types only. + +Examples:: + + // Single triagonal extraction + A = [[1.0, 2.0], + [3.0, 4.0]] + + extracttrian(A) = [1.0, 3.0, 4.0] + extracttrian(A, lower=False) = [1.0, 2.0, 4.0] + extracttrian(A, 1) = [2.0] + extracttrian(A, -1) = [3.0] + + // Batch triagonal extraction + A = [[[1.0, 2.0], + [3.0, 4.0]], + [[5.0, 6.0], + [7.0, 8.0]]] + + extracttrian(A) = [[1.0, 3.0, 4.0], + [5.0, 7.0, 8.0]] +)code" ADD_FILELINE) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FListInputNames", [](const NodeAttrs& attrs) + { return std::vector{"A"}; } ) +.set_attr("FInferShape", LaDiagTrianShape) +.set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FCompute", LaOpForward) +.set_attr("FGradient", ElemwiseGradUseNone{"_backward_linalg_extracttrian"}) +.add_argument("A", "NDArray-or-Symbol", "Tensor of square matrices") +.add_arguments(LaTrianParam::__FIELDS__()); + +NNVM_REGISTER_OP(_backward_linalg_extracttrian) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FResourceRequest", [](const NodeAttrs& attrs) + { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("TIsBackward", true) +.set_attr("FCompute", LaOpBackward); + +NNVM_REGISTER_OP(_linalg_maketrian) +.add_alias("linalg_maketrian") +.describe(R"code(Constructs a square matrix with the input representing a specific triangular sub-matrix. +This is basically the inverse of *linalg.extracttrian*. Input is a tensor *A* of dimension *n >= 1*. + +If *n=1*, then *A* represents the entries of a triangular matrix which is lower triangular if *offset<0* or *offset=0*, *lower=true*. The resulting matrix is derived by first constructing the square +matrix with the entries outside the triangle set to zero and then adding *offset*-times an additional +diagonal with zero entries to the square matrix. + +If *n>1*, then *A* represents a batch of triangular sub-matrices. The batch of corresponding square matrices is returned as an *n+1*-dimensional tensor. + +.. note:: The operator supports float32 and float64 data types only. + +Examples:: + + // Single matrix construction + A = [1.0, 2.0, 3.0] + + maketrian(A) = [[1.0, 0.0], + [2.0, 3.0]] + + maketrian(A, lower=false) = [[1.0, 2.0], + [0.0, 3.0]] + + maketrian(A, offset=1) = [[0.0, 1.0, 2.0], + [0.0, 0.0, 3.0], + [0.0, 0.0, 0.0]] + maketrian(A, offset=-1) = [[0.0, 0.0, 0.0], + [1.0, 0.0, 0.0], + [2.0, 3.0, 0.0]] + + // Batch matrix construction + A = [[1.0, 2.0, 3.0], + [4.0, 5.0, 6.0]] + + maketrian(A) = [[[1.0, 0.0], + [2.0, 3.0]], + [[4.0, 0.0], + [5.0, 6.0]]] + + maketrian(A, offset=1) = [[[0.0, 1.0, 2.0], + [0.0, 0.0, 3.0], + [0.0, 0.0, 0.0]], + [[0.0, 4.0, 5.0], + [0.0, 0.0, 6.0], + [0.0, 0.0, 0.0]]] +)code" ADD_FILELINE) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FListInputNames", [](const NodeAttrs& attrs) + { return std::vector{"A"}; } ) +.set_attr("FInferShape", LaDiagTrianShape) +.set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FCompute", LaOpForward) +.set_attr("FGradient", ElemwiseGradUseNone{"_backward_linalg_maketrian"}) +.add_argument("A", "NDArray-or-Symbol", "Tensor of triangular matrices stored as vectors") +.add_arguments(LaTrianParam::__FIELDS__()); + +NNVM_REGISTER_OP(_backward_linalg_maketrian) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FResourceRequest", [](const NodeAttrs& attrs) + { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("TIsBackward", true) +.set_attr("FCompute", LaOpBackward); + NNVM_REGISTER_OP(_linalg_syrk) .add_alias("linalg_syrk") .describe(R"code(Multiplication of matrix with its transpose. diff --git a/src/operator/tensor/la_op.cu b/src/operator/tensor/la_op.cu index 29a48466313c..ec310fe76fcd 100644 --- a/src/operator/tensor/la_op.cu +++ b/src/operator/tensor/la_op.cu @@ -63,6 +63,30 @@ NNVM_REGISTER_OP(_linalg_sumlogdiag) NNVM_REGISTER_OP(_backward_linalg_sumlogdiag) .set_attr("FCompute", LaOpBackward); +NNVM_REGISTER_OP(_linalg_extractdiag) +.set_attr("FCompute", LaOpForward); + +NNVM_REGISTER_OP(_backward_linalg_extractdiag) +.set_attr("FCompute", LaOpBackward); + +NNVM_REGISTER_OP(_linalg_makediag) +.set_attr("FCompute", LaOpForward); + +NNVM_REGISTER_OP(_backward_linalg_makediag) +.set_attr("FCompute", LaOpBackward); + +NNVM_REGISTER_OP(_linalg_extracttrian) +.set_attr("FCompute", LaOpForward); + +NNVM_REGISTER_OP(_backward_linalg_extracttrian) +.set_attr("FCompute", LaOpBackward); + +NNVM_REGISTER_OP(_linalg_maketrian) +.set_attr("FCompute", LaOpForward); + +NNVM_REGISTER_OP(_backward_linalg_maketrian) +.set_attr("FCompute", LaOpBackward); + NNVM_REGISTER_OP(_linalg_potri) .set_attr("FCompute", LaOpForward); diff --git a/src/operator/tensor/la_op.h b/src/operator/tensor/la_op.h index 5e18e0ef5a25..ba996fc5fe2e 100644 --- a/src/operator/tensor/la_op.h +++ b/src/operator/tensor/la_op.h @@ -129,6 +129,33 @@ struct LaSyrkParam : public dmlc::Parameter { } }; +// Parameters for diag extraction/creation. +struct LaDiagParam : public dmlc::Parameter { + int offset; + DMLC_DECLARE_PARAMETER(LaDiagParam) { + DMLC_DECLARE_FIELD(offset) + .set_default(0) + .describe("Offset of the diagonal versus the main diagonal. 0 corresponds to the main " + "diagonal, a negative/positive value to diagonals below/above the main diagonal."); + } +}; + +// Parameters for trian extraction/creation. +struct LaTrianParam : public dmlc::Parameter { + int offset; + bool lower; + DMLC_DECLARE_PARAMETER(LaTrianParam) { + DMLC_DECLARE_FIELD(offset) + .set_default(0) + .describe("Offset of the diagonal versus the main diagonal. 0 corresponds to the main " + "diagonal, a negative/positive value to diagonals below/above the main diagonal."); + DMLC_DECLARE_FIELD(lower) + .set_default(true) + .describe("Refer to the lower triangular matrix if lower=true, refer to the upper otherwise." + " Only relevant when offset=0"); + } +}; + // Common function for shape inference for matrix mult and matrix mac. inline bool LaMatrixMultMacOpShape(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector* in_attrs, @@ -262,6 +289,47 @@ inline bool LaReduceShape(const nnvm::NodeAttrs& attrs, return true; } +template +inline bool LaDiagTrianShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector* in_attrs, + mxnet::ShapeVector* out_attrs) { + CHECK_EQ(in_attrs->size(), 1); + CHECK_EQ(out_attrs->size(), 1); + const int ndim((*in_attrs)[0].ndim()); + // Only infer in forward direction + if (ndim == 0) { + return false; + } + const int offset = (diag ? nnvm::get(attrs.parsed).offset + : nnvm::get(attrs.parsed).offset); + std::vector oshape(extract ? ndim-1 : ndim+1); + for (int i = 0; i < ndim-1; ++i) { + oshape[i] = (*in_attrs)[0][i]; + } + if (extract) { + CHECK_GE(ndim, 2) + << "Input operand must be a tensor of matrices"; + CHECK_EQ((*in_attrs)[0][ndim-2], (*in_attrs)[0][ndim-1]) + << "Input operand must be a tensor of square matrices"; + const int n((*in_attrs)[0][ndim-1]-abs(offset)); + CHECK_GT(n, 0) + << "Illegal offset " << offset << " for diag/trian extraction of matrix with dimension " + << ndim; + oshape[ndim-2] = (diag ? n : (n*(n+1))/2); + } else if (diag) { + oshape[ndim] = oshape[ndim-1] = (*in_attrs)[0][ndim-1]+abs(offset); + } else { + const int n((*in_attrs)[0][ndim-1]); + const int m(std::floor(0.5+(std::sqrt(8*n+1)-1.0)*0.5)); + CHECK_EQ((m*(m+1))/2, n) + << "Input tensor of maketrian has an invalid dimension for the last axis."; + oshape[ndim] = oshape[ndim-1] = m+abs(offset); + } + mxnet::TShape tshape(oshape.begin(), oshape.end()); + SHAPE_ASSIGN_CHECK(*out_attrs, 0, tshape); + return true; +} + // Shape inference function for linalg_syrk inline bool LaSyrkShape(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector* in_attrs, diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 845ae113c218..0647721d17e7 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -6128,6 +6128,51 @@ def test_laop_4(): #print('float32') check_fw(test_syevd, [a_np], [u_np, l_np], np.float32) +def test_laop_5(): + # tests for diagonal and triangular matrix extraction and generation + data = mx.symbol.Variable('data') + # test complete range of small matrices to cover corner cases + for n in range(1, 10): + # test batched and non-batched processing + for b in range(3): + shape = (n, n) if b == 0 else (b, n, n) + data_in = np.random.uniform(1, 10, shape) + # test all legal offsets of the diagonal + for offs in range(1-n, n): + # test extraction of diagonal + test_diag = mx.sym.linalg.extractdiag(data, offset=offs) + res_diag = np.diagonal(data_in, offset=offs) if b==0 else np.diagonal(data_in, axis1=1, axis2=2, offset=offs) + check_symbolic_forward(test_diag, [data_in], [res_diag]) + check_numeric_gradient(test_diag, [data_in]) + # test generation of diagonal matrix + test_diag2 = mx.sym.linalg.makediag(data, offset=offs) + res_diag2 = None + if b == 0: + res_diag2 = np.diagflat(res_diag, k=offs) + else: + for i in range(b): + res = np.reshape(np.diagflat(res_diag[i], k=offs), (1, n, n)) + res_diag2 = res if res_diag2 is None else np.concatenate((res_diag2, res), axis=0) + check_symbolic_forward(test_diag2, [res_diag], [res_diag2]) + check_numeric_gradient(test_diag2, [res_diag]) + # check both settings for parameter "lower" in case of zero offset + lower_vals = [True] if offs != 0 else [True, False] + for lower in lower_vals: + # test extraction of triangle by doing a full roundtrip as the intermediate extracted + # triangle has different orderings than numpy. + test_trian = mx.sym.linalg.extracttrian(data, offset=offs, lower=lower) + test_trian = mx.sym.linalg.maketrian(test_trian, offset=offs, lower=lower) + extracts_lower = (offs < 0) or ((offs == 0) and lower) + res_trian = None + if b == 0: + res_trian = np.tril(data_in, offs) if extracts_lower else np.triu(data_in, offs) + else: + for i in range(b): + res = np.tril(data_in[i], offs) if extracts_lower else np.triu(data_in[i], offs) + res = np.reshape(res, (1, n, n)) + res_trian = res if res_trian is None else np.concatenate((res_trian, res), axis=0) + check_symbolic_forward(test_trian, [data_in], [res_trian]) + check_numeric_gradient(test_trian, [data_in]) @with_seed() def test_stack():