support for upper triangular matrices in linalg (apache#12904)
asmushetzel authored and azai91 committed Dec 1, 2018
1 parent e794333 commit 82855bf
Showing 4 changed files with 300 additions and 242 deletions.
236 changes: 131 additions & 105 deletions src/operator/tensor/la_op-inl.h
@@ -21,6 +21,7 @@
* Copyright (c) 2017 by Contributors
* \file la_op-inl.h
* \brief Operators for advanced linear algebra.
* \note See https://arxiv.org/pdf/1710.08717.pdf for details of gradient computations.
*/
#ifndef MXNET_OPERATOR_TENSOR_LA_OP_INL_H_
#define MXNET_OPERATOR_TENSOR_LA_OP_INL_H_
@@ -32,20 +33,29 @@ namespace op {

using namespace mshadow;

// Helper functions.
struct CopyLowerToUpper {
// Copies the lower/upper triangular part to the upper/lower part, i.e. to the opposite side.
struct CopyTriangularToOppositeSide {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, int matrix_size, int stride, DType* data) {
MSHADOW_XINLINE static void Map(int i, int matrix_size, int stride, DType* data, bool to_lower) {
// Below computation works even when we are dealing with a batch of matrices.
const int row((i % matrix_size) / stride), col(i % stride);
if ( row > col ) data[i + (col - row) * (stride - 1)] = data[i];
if (row > col) {
if (to_lower) {
data[i] = data[i + (col - row) * (stride - 1)];
} else {
data[i + (col - row) * (stride - 1)] = data[i];
}
}
}
};
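The offset used above, (col - row) * (stride - 1), is the distance in a row-major layout between element (row, col) and its mirror (col, row): (col*stride + row) - (row*stride + col) = (col - row) * (stride - 1). A minimal standalone sketch of the same arithmetic, assuming a single row-major n x n matrix with stride == n (it is not part of this commit and only illustrates the indexing):

#include <cstdio>

// Mirrors the strictly lower triangle of a row-major n x n matrix onto the
// upper triangle (to_lower = false) or the other way around (to_lower = true),
// using the same offset formula as the kernel above.
void copy_triangular_to_opposite(double* data, int n, bool to_lower) {
  for (int i = 0; i < n * n; ++i) {
    const int row = i / n, col = i % n;
    if (row > col) {  // element lies in the strictly lower triangle
      const int mirror = i + (col - row) * (n - 1);  // flat index of (col, row)
      if (to_lower) {
        data[i] = data[mirror];
      } else {
        data[mirror] = data[i];
      }
    }
  }
}

int main() {
  double m[9] = {1, 0, 0,
                 4, 5, 0,
                 7, 8, 9};
  copy_triangular_to_opposite(m, 3, /*to_lower=*/false);
  // Prints the symmetrized matrix: 1 4 7 / 4 5 8 / 7 8 9
  for (int r = 0; r < 3; ++r) {
    std::printf("%g %g %g\n", m[3 * r], m[3 * r + 1], m[3 * r + 2]);
  }
  return 0;
}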
struct ZeroUpper {

// Zeros the lower/upper triangular part of a matrix.
struct ZeroTriangular {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, int matrix_size, int stride, DType* data) {
MSHADOW_XINLINE static void Map(int i, int matrix_size, int stride, DType* data,
bool zero_lower) {
const int row((i % matrix_size) / stride), col(i % stride);
if ( row < col ) data[i] = 0;
if ((!zero_lower && (row < col)) || (zero_lower && (row > col))) data[i] = 0;
}
};
struct Scale {
@@ -103,87 +113,91 @@ struct gemm2 {
}
};

// L = potrf(A).
// B = potrf(A).
struct potrf {
template<typename xpu, typename DType>
static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& L,
static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
Stream<xpu> *s, const nnvm::NodeAttrs& attrs) {
if ( A.dptr_ != L.dptr_ ) Copy(L, A, s);
linalg_batch_potrf(L, true, s);
const LaCholeskyParam& param = nnvm::get<LaCholeskyParam>(attrs.parsed);
if ( A.dptr_ != B.dptr_ ) Copy(B, A, s);
linalg_batch_potrf(B, param.lower, s);
using namespace mxnet_op;
Kernel<ZeroUpper, xpu>::Launch(s, L.MSize(), L.size(1)*L.stride_, L.stride_, L.dptr_);
Kernel<ZeroTriangular, xpu>::Launch(s, B.MSize(), B.size(1)*B.stride_, B.stride_,
B.dptr_, !param.lower);
}
template<typename xpu, typename DType>
static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& L,
static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
const OpContext& ctx, const nnvm::NodeAttrs& attrs) {
Stream<xpu> *s = ctx.get_stream<xpu>();
op(A, L, s, attrs);
op(A, B, s, attrs);
}
};
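For orientation, and assuming the usual LAPACK potrf convention that linalg_batch_potrf wraps: with param.lower = true the result B is lower triangular with A = B * B**T, and with param.lower = false it is upper triangular with A = B**T * B. The ZeroTriangular launch then clears the half of B that the factorization leaves untouched.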

// A = potri(L).
// A = potri(B).
struct potri {
template<typename xpu, typename DType>
static void op(const Tensor<xpu, 3, DType>& L, const Tensor<xpu, 3, DType>& A,
static void op(const Tensor<xpu, 3, DType>& B, const Tensor<xpu, 3, DType>& A,
Stream<xpu> *s, const nnvm::NodeAttrs& attrs) {
if ( A.dptr_ != L.dptr_ ) Copy(A, L, s);
linalg_batch_potri(A, true, s);
const LaCholeskyParam& param = nnvm::get<LaCholeskyParam>(attrs.parsed);
if ( A.dptr_ != B.dptr_ ) Copy(A, B, s);
linalg_batch_potri(A, param.lower, s);
using namespace mxnet_op;
Kernel<CopyLowerToUpper, xpu>::Launch(s, A.MSize(), A.size(1)*A.stride_, A.stride_, A.dptr_);
Kernel<CopyTriangularToOppositeSide, xpu>::Launch(s, A.MSize(), A.size(1)*A.stride_, A.stride_,
A.dptr_, !param.lower);
}
template<typename xpu, typename DType>
static void op(const Tensor<xpu, 3, DType>& L, const Tensor<xpu, 3, DType>& A,
static void op(const Tensor<xpu, 3, DType>& B, const Tensor<xpu, 3, DType>& A,
const OpContext& ctx, const nnvm::NodeAttrs& attrs) {
Stream<xpu> *s = ctx.get_stream<xpu>();
op(L, A, s, attrs);
op(B, A, s, attrs);
}
};
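Again assuming the LAPACK potri convention behind linalg_batch_potri: given the Cholesky factor B, the operator returns A = (B * B**T)**(-1) when param.lower is true and A = (B**T * B)**(-1) otherwise. Because potri only writes one triangle of the inverse, the CopyTriangularToOppositeSide launch restores the full symmetric matrix.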

// B = trsm(L,A)
// C = trsm(A,B)
struct trsm {
template<typename xpu, typename DType>
static void op(const Tensor<xpu, 3, DType>& L, const Tensor<xpu, 3, DType>& B,
DType alpha, bool rightside, bool transpose, Stream<xpu> *s) {
linalg_batch_trsm(L, B, alpha, rightside, true, transpose, s);
static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& C,
DType alpha, bool rightside, bool lower, bool transpose, Stream<xpu> *s) {
linalg_batch_trsm(A, C, alpha, rightside, lower, transpose, s);
}
template<typename xpu, typename DType>
static void op(const Tensor<xpu, 3, DType>& L, const Tensor<xpu, 3, DType>& A,
const Tensor<xpu, 3, DType>& B,
static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
const Tensor<xpu, 3, DType>& C,
Stream<xpu> *s, const nnvm::NodeAttrs& attrs) {
if ( A.dptr_ != B.dptr_ ) Copy(B, A, s);
if ( B.dptr_ != C.dptr_ ) Copy(C, B, s);
const LaTriangMatrixMultParam& param = nnvm::get<LaTriangMatrixMultParam>(attrs.parsed);
op(L, B, DType(param.alpha), param.rightside, param.transpose, s);
op(A, C, DType(param.alpha), param.rightside, param.lower, param.transpose, s);
}
template<typename xpu, typename DType>
static void op(const Tensor<xpu, 3, DType>& L, const Tensor<xpu, 3, DType>& A,
const Tensor<xpu, 3, DType>& B,
static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
const Tensor<xpu, 3, DType>& C,
const OpContext& ctx, const nnvm::NodeAttrs& attrs) {
Stream<xpu> *s = ctx.get_stream<xpu>();
op(L, A, B, s, attrs);
op(A, B, C, s, attrs);
}
};
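Reading the flags in the standard BLAS trsm sense (an assumption about linalg_batch_trsm that matches its use throughout this file): with op(A) = A**T if param.transpose and op(A) = A otherwise, the operator overwrites C with the solution of op(A) * X = alpha * B when rightside is false, or of X * op(A) = alpha * B when rightside is true, treating A as lower or upper triangular according to param.lower.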

// B = trmm(L,A)
// C = trmm(A,B)
struct trmm {
template<typename xpu, typename DType>
static void op(const Tensor<xpu, 3, DType>& L, const Tensor<xpu, 3, DType>& B,
DType alpha, bool rightside, bool transpose, Stream<xpu> *s) {
linalg_batch_trmm(L, B, alpha, rightside, true, transpose, s);
static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& C,
DType alpha, bool rightside, bool lower, bool transpose, Stream<xpu> *s) {
linalg_batch_trmm(A, C, alpha, rightside, lower, transpose, s);
}
template<typename xpu, typename DType>
static void op(const Tensor<xpu, 3, DType>& L, const Tensor<xpu, 3, DType>& A,
const Tensor<xpu, 3, DType>& B, Stream<xpu> *s,
static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
const Tensor<xpu, 3, DType>& C, Stream<xpu> *s,
const nnvm::NodeAttrs& attrs) {
if ( A.dptr_ != B.dptr_ ) Copy(B, A, s);
if ( B.dptr_ != C.dptr_ ) Copy(C, B, s);
const LaTriangMatrixMultParam& param = nnvm::get<LaTriangMatrixMultParam>(attrs.parsed);
op(L, B, DType(param.alpha), param.rightside, param.transpose, s);
op(A, C, DType(param.alpha), param.rightside, param.lower, param.transpose, s);
}
template<typename xpu, typename DType>
static void op(const Tensor<xpu, 3, DType>& L, const Tensor<xpu, 3, DType>& A,
const Tensor<xpu, 3, DType>& B, const OpContext& ctx,
static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
const Tensor<xpu, 3, DType>& C, const OpContext& ctx,
const nnvm::NodeAttrs& attrs) {
Stream<xpu> *s = ctx.get_stream<xpu>();
op(L, A, B, s, attrs);
op(A, B, C, s, attrs);
}
};
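Analogously for trmm, under the same BLAS-convention assumption: the operator overwrites C with alpha * op(A) * B when rightside is false and with alpha * B * op(A) when rightside is true, with op(A) and the referenced triangle of A again selected by param.transpose and param.lower.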

@@ -223,8 +237,8 @@ struct syrk {
linalg_batch_syrk(A, B, alpha, beta, tA, s);
// Symmetric B is in lower triangle: Copy to upper
using namespace mxnet_op;
Kernel<CopyLowerToUpper, xpu>::Launch(s, B.MSize(), B.size(1)*B.stride_,
B.stride_, B.dptr_);
Kernel<CopyTriangularToOppositeSide, xpu>::Launch(s, B.MSize(), B.size(1)*B.stride_,
B.stride_, B.dptr_, false);
}
template<typename xpu, typename DType>
static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
@@ -276,8 +290,8 @@ struct gelqf {
Tensor<xpu, 2, DType> QLeft(Qi.dptr_, Shape2(m, m), Qi.stride_, s);
Copy(Li, QLeft, s);
using namespace mxnet_op;
Kernel<ZeroUpper, xpu>::Launch(s, Li.MSize(), m*Li.stride_, Li.stride_,
Li.dptr_);
Kernel<ZeroTriangular, xpu>::Launch(s, Li.MSize(), m*Li.stride_, Li.stride_,
Li.dptr_, false);
// Call orglq: Input is Qi and part of work. Overwrites Qi by final Q
// matrix (conversion from internal representation)
linalg_orglq(Qi, work, s);
@@ -395,117 +409,129 @@ struct gemm2_backward {

struct potrf_backward {
template<typename xpu, typename DType>
static void op(const Tensor<xpu, 3, DType>& dL, const Tensor<xpu, 3, DType>& L,
static void op(const Tensor<xpu, 3, DType>& dB, const Tensor<xpu, 3, DType>& B,
const Tensor<xpu, 3, DType>& dA,
Stream<xpu>* s, const nnvm::NodeAttrs& attrs) {
// Backward of L = potrf(A).
// dA = 0.5 * L**T * copyLTU(L**T * dL) * L**(-1)
// Backward of B = potrf(A).
// dA = 0.5 * B**T * copyLTU(B**T * dB) * B**(-1)
// Here, copyLTU(M) creates a symmetric matrix from the square matrix M
// by setting the upper triangle to be equal to the lower triangle, leaving
// lower triangle and diagonal unchanged.
if ( dL.dptr_ != dA.dptr_ ) {
Copy(dA, dL, s);
// The function also handles the case when B is upper triangular by appropriate
// transpositions.
const LaCholeskyParam& param = nnvm::get<LaCholeskyParam>(attrs.parsed);
if ( dB.dptr_ != dA.dptr_ ) {
Copy(dA, dB, s);
}
trmm::op(L, dA, DType(1.0), false, true, s);
trmm::op(B, dA, DType(1.0), !param.lower, param.lower, true, s);
using namespace mxnet_op;
Kernel<CopyLowerToUpper, xpu>::Launch
(s, dA.MSize(), dA.size(1)*dA.stride_, dA.stride_, dA.dptr_);
trsm::op(L, dA, DType(1.0), false, true, s);
trsm::op(L, dA, DType(0.5), true, false, s);
Kernel<CopyTriangularToOppositeSide, xpu>::Launch
(s, dA.MSize(), dA.size(1)*dA.stride_, dA.stride_, dA.dptr_, !param.lower);
trsm::op(B, dA, DType(1.0), false, param.lower, param.lower, s);
trsm::op(B, dA, DType(0.5), true, param.lower, !param.lower, s);
}
template<typename xpu, typename DType>
static void op(const Tensor<xpu, 3, DType>& dL, const Tensor<xpu, 3, DType>& L,
static void op(const Tensor<xpu, 3, DType>& dB, const Tensor<xpu, 3, DType>& B,
const Tensor<xpu, 3, DType>& dA,
const OpContext& ctx, const nnvm::NodeAttrs& attrs) {
Stream<xpu> *s = ctx.get_stream<xpu>();
op(dL, L, dA, s, attrs);
op(dB, B, dA, s, attrs);
}
};
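Reading the lower-triangular case (param.lower = true) off these calls: the trmm forms B**T * dB, copyLTU symmetrizes it, and the two trsm calls multiply by B**(-T) from the left and by B**(-1) from the right, so that dA = 0.5 * B**(-T) * copyLTU(B**T * dB) * B**(-1). The upper-triangular case is the corresponding transposed expression; this reading is derived from the flag settings rather than stated in the commit.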

struct potri_backward {
template<typename xpu, typename DType>
static void op(const Tensor<xpu, 3, DType>& dA, const Tensor<xpu, 3, DType>& L,
const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& dL,
static void op(const Tensor<xpu, 3, DType>& dA, const Tensor<xpu, 3, DType>& B,
const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& dB,
Stream<xpu>* s, const nnvm::NodeAttrs& attrs) {
// Backward of A = potri(L).
// dL = -tril( A * (dA + dA**T) * L**(-T)), where tril() extracts lower triangle
// Backward of A = potri(B).
// dB = -tril( A * (dA + dA**T) * B**(-T)), where tril() extracts lower triangle
// and diagonal. We must not assume that dA is symmetric.
// The function also handles the case when B is upper triangular by appropriate
// transpositions.
// Note: Calling gemm twice here is a bit wasteful, but otherwise the symmetrization
// of dA would require temporary memory.
gemm::op(A, dA, dL, DType(1.), DType(0.), false, false, s);
gemm::op(A, dA, dL, DType(1.), DType(1.), false, true, s);
trsm::op(L, dL, DType(-1.), true, true, s);
const LaCholeskyParam& param = nnvm::get<LaCholeskyParam>(attrs.parsed);
if (param.lower) {
gemm::op(A, dA, dB, DType(1.), DType(0.), false, false, s);
gemm::op(A, dA, dB, DType(1.), DType(1.), false, true, s);
} else {
gemm::op(dA, A, dB, DType(1.), DType(0.), false, false, s);
gemm::op(dA, A, dB, DType(1.), DType(1.), true, false, s);
}
trsm::op(B, dB, DType(-1.), param.lower, param.lower, true, s);
using namespace mxnet_op;
Kernel<ZeroUpper, xpu>::Launch(s, dL.MSize(), dL.size(1)*dL.stride_, dL.stride_,
dL.dptr_);
Kernel<ZeroTriangular, xpu>::Launch(s, dB.MSize(), dB.size(1)*dB.stride_, dB.stride_,
dB.dptr_, !param.lower);
}
template<typename xpu, typename DType>
static void op(const Tensor<xpu, 3, DType>& dA, const Tensor<xpu, 3, DType>& L,
const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& dL,
static void op(const Tensor<xpu, 3, DType>& dA, const Tensor<xpu, 3, DType>& B,
const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& dB,
const OpContext& ctx, const nnvm::NodeAttrs& attrs) {
Stream<xpu> *s = ctx.get_stream<xpu>();
op(dA, L, A, dL, s, attrs);
op(dA, B, A, dB, s, attrs);
}
};
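Spelled out for both branches: with param.lower = true the gemm pair forms A * (dA + dA**T), the trsm multiplies by B**(-T) from the right, and the ZeroTriangular launch keeps the lower triangle, giving dB = -tril(A * (dA + dA**T) * B**(-T)); with param.lower = false the order is mirrored, dB = -triu(B**(-T) * (dA + dA**T) * A).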

struct trsm_backward {
template<typename xpu, typename DType>
static void op(const Tensor<xpu, 3, DType>& dB, const Tensor<xpu, 3, DType>& L,
const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
const Tensor<xpu, 3, DType>& dL, const Tensor<xpu, 3, DType>& dA,
static void op(const Tensor<xpu, 3, DType>& dC, const Tensor<xpu, 3, DType>& A,
const Tensor<xpu, 3, DType>& B, const Tensor<xpu, 3, DType>& C,
const Tensor<xpu, 3, DType>& dA, const Tensor<xpu, 3, DType>& dB,
Stream<xpu>* s, const nnvm::NodeAttrs& attrs) {
// Backward of B = trsm(L,A).
// Backward of C = trsm(A,B).
const LaTriangMatrixMultParam& param = nnvm::get<LaTriangMatrixMultParam>(attrs.parsed);
// Compute dB
if ( dB.dptr_ != dC.dptr_ ) Copy(dB, dC, s);
trsm::op(A, dB, DType(param.alpha), param.rightside, param.lower, !param.transpose, s);
// Compute dA
if ( dA.dptr_ != dB.dptr_ ) Copy(dA, dB, s);
trsm::op(L, dA, DType(param.alpha), param.rightside, !param.transpose, s);
// Compute dL
const bool da_left(param.rightside == param.transpose);
DType scale(-1.0/param.alpha);
(da_left ? gemm::op(dA, B, dL, scale, DType(0), param.transpose, !param.transpose, s)
: gemm::op(B, dA, dL, scale, DType(0), !param.transpose, param.transpose, s));
(da_left ? gemm::op(dB, C, dA, scale, DType(0), param.transpose, !param.transpose, s)
: gemm::op(C, dB, dA, scale, DType(0), !param.transpose, param.transpose, s));
using namespace mxnet_op;
Kernel<ZeroUpper, xpu>::Launch(s, dL.MSize(), dL.size(1)*dL.stride_, dL.stride_, dL.dptr_);
Kernel<ZeroTriangular, xpu>::Launch(s, dA.MSize(), dA.size(1)*dA.stride_, dA.stride_,
dA.dptr_, !param.lower);
}
template<typename xpu, typename DType>
static void op(const Tensor<xpu, 3, DType>& dB, const Tensor<xpu, 3, DType>& L,
const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
const Tensor<xpu, 3, DType>& dL, const Tensor<xpu, 3, DType>& dA,
static void op(const Tensor<xpu, 3, DType>& dC, const Tensor<xpu, 3, DType>& A,
const Tensor<xpu, 3, DType>& B, const Tensor<xpu, 3, DType>& C,
const Tensor<xpu, 3, DType>& dA, const Tensor<xpu, 3, DType>& dB,
const OpContext& ctx, const nnvm::NodeAttrs& attrs) {
Stream<xpu> *s = ctx.get_stream<xpu>();
op(dB, L, A, B, dL, dA, s, attrs);
op(dC, A, B, C, dA, dB, s, attrs);
}
};
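As one concrete instance (left-side, non-transposed, lower-triangular, i.e. C = alpha * A**(-1) * B): the first trsm, run with the transpose flag flipped, yields dB = alpha * A**(-T) * dC, and the gemm plus ZeroTriangular then give dA = -(1/alpha) * tril(dB * C**T) = -tril(A**(-T) * dC * C**T). The remaining flag combinations follow the same pattern with the corresponding transposes.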

struct trmm_backward {
template<typename xpu, typename DType>
static void op(const Tensor<xpu, 3, DType>& dB, const Tensor<xpu, 3, DType>& L,
const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& dL,
const Tensor<xpu, 3, DType>& dA, Stream<xpu>* s,
static void op(const Tensor<xpu, 3, DType>& dC, const Tensor<xpu, 3, DType>& A,
const Tensor<xpu, 3, DType>& B, const Tensor<xpu, 3, DType>& dA,
const Tensor<xpu, 3, DType>& dB, Stream<xpu>* s,
const nnvm::NodeAttrs& attrs) {
// Backward of B = trmm(L,A).
// Backward of C = trmm(A,B).
const LaTriangMatrixMultParam& param = nnvm::get<LaTriangMatrixMultParam>(attrs.parsed);
// Compute dL
// Compute dA
DType scale(param.alpha);
if (param.rightside == param.transpose) {
gemm::op(dB, A, dL, scale, DType(0.), param.transpose, !param.transpose, s);
gemm::op(dC, B, dA, scale, DType(0.), param.transpose, !param.transpose, s);
} else {
gemm::op(A, dB, dL, scale, DType(0.), !param.transpose, param.transpose, s);
gemm::op(B, dC, dA, scale, DType(0.), !param.transpose, param.transpose, s);
}
using namespace mxnet_op;
Kernel<ZeroUpper, xpu>::Launch(s, dL.MSize(), dL.size(1)*dL.stride_, dL.stride_,
dL.dptr_);
// Compute dA
if (dA.dptr_ != dB.dptr_) Copy(dA, dB, s);
trmm::op(L, dA, scale, param.rightside, !param.transpose, s);
Kernel<ZeroTriangular, xpu>::Launch(s, dA.MSize(), dA.size(1)*dA.stride_, dA.stride_,
dA.dptr_, !param.lower);
// Compute dB
if (dB.dptr_ != dC.dptr_) Copy(dB, dC, s);
trmm::op(A, dB, scale, param.rightside, param.lower, !param.transpose, s);
}
template<typename xpu, typename DType>
static void op(const Tensor<xpu, 3, DType>& dB, const Tensor<xpu, 3, DType>& L,
const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& dL,
const Tensor<xpu, 3, DType>& dA, const OpContext& ctx,
static void op(const Tensor<xpu, 3, DType>& dC, const Tensor<xpu, 3, DType>& A,
const Tensor<xpu, 3, DType>& B, const Tensor<xpu, 3, DType>& dA,
const Tensor<xpu, 3, DType>& dB, const OpContext& ctx,
const nnvm::NodeAttrs& attrs) {
Stream<xpu> *s = ctx.get_stream<xpu>();
op(dB, L, A, dL, dA, s, attrs);
op(dC, A, B, dA, dB, s, attrs);
}
};
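For the same concrete instance in the trmm case (C = alpha * A * B, left-side, non-transposed, lower-triangular): the gemm plus ZeroTriangular give dA = alpha * tril(dC * B**T), and the final trmm gives dB = alpha * A**T * dC. The other flag combinations again only change which factors are transposed.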

@@ -586,13 +612,13 @@ struct gelqf_backward {
Tensor<xpu, 3, DType> tempM = ctx.requested[0]
.get_space_typed<xpu, 3, DType>(dL.shape_, s);
Copy(tempM, dL, s);
trmm::op(L, tempM, DType(1.0), false, true, s);
trmm::op(L, tempM, DType(1.0), false, true, true, s);
gemm::op(dA, Q, tempM, DType(-1.0), DType(1.0), false, true, s);
Kernel<CopyLowerToUpper, xpu>::Launch
Kernel<CopyTriangularToOppositeSide, xpu>::Launch
(s, tempM.MSize(), tempM.size(1)*tempM.stride_, tempM.stride_,
tempM.dptr_);
tempM.dptr_, false);
gemm::op(tempM, Q, dA, DType(1.0), DType(1.0), false, false, s);
trsm::op(L, dA, DType(1.0), false, true, s);
trsm::op(L, dA, DType(1.0), false, true, true, s);
}
};
