
Add matrix determinant operator in linalg #15007

Merged: 30 commits merged into apache:master from arcadiaphy:pr_det on Aug 7, 2019

Conversation

@arcadiaphy (Member) commented May 20, 2019

Description

This is the second PR on linalg enhancement. Three operators for matrix determinants are added: det, logdet, slogdet.

Some points worth mentioning in this PR:

  1. The matrix inversion code is refactored a little to remove a duplicated workspace query in getri and to make it more readable and easier to reuse in the det operators.
  2. Log(det(x)) in logdet follows the log operator: the gradient is passed backward even when det(x) < 0.
  3. The sign output of slogdet follows the sign operator: its gradient is ignored in the backward pass since it is not properly defined.
  4. The determinant calculation uses LU factorization with partial pivoting.
  5. The gradient of the determinant is derived from Jacobi's formula, which has a numerically friendly closed-form solution when the input matrix A is invertible. The non-invertible case is not easy to implement since it involves the adjugate matrix. In tensorflow this case is ignored, while pytorch uses SVD to compute the gradient. In this PR it is left for future work; as a temporary measure, no gradient is passed backward when det = 0. My inclination is to reuse the LU factorization instead of SVD for the non-invertible case, since it is already computed. (A small numerical check of Jacobi's formula is sketched right after this list.)
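
To make the last point concrete, here is a small numpy-only sketch (an editor's illustration, not code from this PR) that checks Jacobi's formula d det(A)/dA = det(A) * inv(A)^T against central finite differences for an invertible matrix:

import numpy as np

def det_grad(A):
    # Closed-form gradient from Jacobi's formula, valid when A is invertible.
    return np.linalg.det(A) * np.linalg.inv(A).T

A = np.array([[1., 4.], [2., 3.]])
eps = 1e-6
num = np.zeros_like(A)
for i in range(2):
    for j in range(2):
        E = np.zeros_like(A)
        E[i, j] = eps
        num[i, j] = (np.linalg.det(A + E) - np.linalg.det(A - E)) / (2 * eps)

print(np.allclose(det_grad(A), num))  # True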

Checklist

Essentials

Please feel free to remove inapplicable items for your PR.

  • The PR title starts with [MXNET-$JIRA_ID], where $JIRA_ID refers to the relevant JIRA issue created (except PRs with tiny changes)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • For new examples, README.md is added to explain what the example does, the source of the dataset, expected performance on the test set, and a reference to the original paper if applicable
  • Check the API doc at http://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-$PR_ID/$BUILD_ID/index.html
  • To the best of my knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

Changes

  • Feature1, tests, (and when applicable, API doc)
  • Feature2, tests, (and when applicable, API doc)

Comments

  • If this change is a backward incompatible change, why must this change be made.
  • Interesting edge cases to note here

@pinaraws

@mxnet-label-bot add[Operator, pr-work-in-progress]

@marcoabreu added the Operator and pr-work-in-progress (PR is still work in progress) labels on May 20, 2019
@arcadiaphy (Member, Author)

@eric-haibin-lin @reminisce @apeforest @anirudh2290
Is anyone free to review my enhancement to the linalg package? Also, my last PR #14963 on matrix inversion is not fully reviewed, so please help review that part too. Thanks!

@arcadiaphy added the pr-awaiting-review (PR is waiting for code review) label on May 21, 2019
@szha (Member) commented May 21, 2019

cc @asmushetzel

@reminisce (Contributor) left a comment

Great work! A few comments on improving shape/type inference logic.


Examples::

// Single matrix inversion
Contributor

I don't think // is needed here since this is not a piece of C++ code.
inversion -> determinant

A = [[1., 4.], [2., 3.]]
det(A) = [-5.]

// Batch matrix inversion
Contributor

Same as above.

mxnet::ShapeVector* out_attrs) {
CHECK_EQ(in_attrs->size(), 1);
CHECK_EQ(out_attrs->size(), onum + 2);
const mxnet::TShape& in = (*in_attrs)[0];
Contributor

Please add this line of code here:

if (!ndim_is_known(in)) return false;

}
SHAPE_ASSIGN_CHECK(*out_attrs, onum, in); /* LU */
SHAPE_ASSIGN_CHECK(*out_attrs, onum + 1, mxnet::TShape(in.begin(), in.end() - 1)); /* pivot */
return true;
Contributor

replace this with

return shape_is_known(in);

int dtype = (*in_type)[0];
CHECK_NE(dtype, -1) << "Input must have specified type";

out_type->clear();
Contributor

Please use TYPE_ASSIGN_CHECK for every output type assignment.

using namespace mshadow;
CHECK_EQ(in_type->size(), 1U);
int dtype = (*in_type)[0];
CHECK_NE(dtype, -1) << "Input must have specified type";
Contributor

change this to

if (dtype == -1) return false;

std::vector<int>* out_type) {
using namespace mshadow;
CHECK_EQ(in_type->size(), 1U);
int dtype = (*in_type)[0];
Contributor

const int dtype

CHECK_EQ(in_type->size(), 1U);
int dtype = (*in_type)[0];
CHECK_NE(dtype, -1) << "Input must have specified type";

Contributor

If only fp32/64 are supported, please add one line here for checking whether dtype is equal to kFloat32/kFloat64.

Member Author

All updated.

@arcadiaphy (Member, Author)

Regarding this point:

The gradient of the determinant is derived from Jacobi's formula, which has a numerically friendly closed-form solution when the input matrix A is invertible. The non-invertible case is not easy to implement since it involves the adjugate matrix. In tensorflow this case is ignored, while pytorch uses SVD to compute the gradient. In this PR it is left for future work; as a temporary measure, no gradient is passed backward when det = 0. My inclination is to reuse the LU factorization instead of SVD for the non-invertible case, since it is already computed.

I've looked into the pytorch code more carefully, and I think their implementation is wrong. A simple example to show this:

In [1]: import torch

In [2]: x = torch.autograd.Variable(torch.tensor([[1., 2.], [2., 4.]]), requires_grad=True)

In [3]: y = x.det()

In [4]: y.backward(torch.ones_like(y))

In [5]: x.grad
Out[5]:
tensor([[0., 0.],
        [0., 0.]])

Since in actual computation it's very hard to hit det == 0 exactly with floats, is it really necessary to implement the non-invertible case?

Also, I haven't thought of a good method yet, so any suggestions are welcome.

@asmushetzel (Contributor)

@mseeger

@@ -939,5 +939,153 @@ NNVM_REGISTER_OP(_backward_linalg_inverse)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FCompute>("FCompute<cpu>", LaOpBackward<cpu, 2, 2, 2, 1, inverse_backward>);

NNVM_REGISTER_OP(_linalg_det)
Contributor

I'm not convinced that we should add all kinds of variants like "det", "logdet", etc. Conceptually there is no end to that (why not have logdetminus1?). Providing them also doesn't save any compute time, as the overwhelming cost is in the core determinant computation and not in a subsequent log or exp. And we end up with a lot of redundant code.

So I would propose providing one generic method, most likely a signed logdet(), and leaving it to the user to apply additional operators when they need some variant. You can even make the "sign" return value optional via a parameter.
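
For illustration, the variants are indeed cheap to recover from a signed logdet; the following uses numpy's slogdet purely as a stand-in for the proposed operator (an editor's sketch, not code from this PR):

import numpy as np

A = np.array([[1., 4.], [2., 3.]])
sign, logabsdet = np.linalg.slogdet(A)

det = sign * np.exp(logabsdet)                    # -5.0
logdet = logabsdet if sign > 0 else float('nan')  # log(det) only defined for det > 0
print(det, logdet)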

@arcadiaphy (Member, Author) commented May 24, 2019

Good point. If we only want to keep one method, then slogdet is the best choice.

Since the community is now working on providing a numpy experience in mxnet, how about we follow the numpy.linalg package and keep det and slogdet?



Examples::

Single matrix inversion
Contributor

I suppose this and other similar comments will be updated (here referring to "matrix inversion").


Examples::

Single matrix inversion
Contributor

"Matrix inversion"

Member Author

Updated.

}
};

// partial pivoting LU decomposition: A = PLU, so det(A) = det(P)det(L)det(U)
Contributor

Misleading comment. This is the final computation of logdet and sign based on an existing LU decomposition, not the LU decomposition itself.

@arcadiaphy (Member, Author) commented May 24, 2019

I want to explain the computation method used in det; I'll change the comment to make it clearer.

}

// Type inference function for linalg_inverse
inline bool InverseType(const nnvm::NodeAttrs& attrs,
Contributor

Is this a combined diff with #14963? The two PRs should either be joined into one PR or be clearly separated in terms of code changes.

Member Author

They're separate PRs; I'll remove the changes not related to this one.

@@ -920,7 +920,7 @@ Examples::
.set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs)
{ return std::vector<std::string>{"A"}; } )
.set_attr<mxnet::FInferShape>("FInferShape", InverseShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<nnvm::FInferType>("FInferType", InverseType)
Contributor

Why do we need an "InverseType"?

@arcadiaphy (Member, Author) commented May 24, 2019

I want to add type checks during layer setup. Right now many size and type checks in the linalg operators are carried out in the forward pass, which is not common practice in mxnet.

But this change is not related to this PR, so I'll remove it.

using namespace mshadow::expr;
Kernel<SignedLogDet, xpu>::Launch(s, pivot.size(0), pivot.size(1), pivot.dptr_,
LU.dptr_, sign.dptr_, logdet.dptr_);
const_cast<Tensor<xpu, 1, DType>&>(logdet) = F<mshadow_op::log>(sign) + logdet;
Contributor

This doesn't make much sense: if sign == 1 then log(1) = 0 (so no effect), and if sign is -1 then this crashes anyway. See my other comments about providing only one variant of det() and leaving any other variant to the user (by applying additional operators).

Member Author

I'll remove logdet.

@@ -825,6 +901,100 @@ struct inverse_backward {
}
};

// Here we set grad to zero if det = 0 as a temporary method
Contributor

We should not add "temporary" methods. It's either the way we want to handle the det() == 0 case or not. It's fully ok to define that the gradient is assumed to be zero in this case.

Member Author

I'm not sure how to handle the det = 0 case. If it's ok to use zero here, I'll change the comment and add a note to the operator docs.

Shape3(det.size(0), 1, 1)), mxnet::TShape(LU.shape_)) * \
transpose(LU, Shape3(0, 2, 1));
Stream<xpu> *s = ctx.get_stream<xpu>();
// stop grad for zero det temporarily
Contributor

"temporarily"

// Compute matrix inversion with LU and pivot using temp workspace,
// the result stores back to LU
template<typename xpu, typename DType>
void linalg_batch_det_helper(const Tensor<xpu, 3, DType>& LU,
Contributor

Misleading comment and naming. It looks to me like this helper is not needed to compute the determinant; it is only used for the backward pass, and it assumes the determinant has already been computed and is available.

Member Author

Updated.

@arcadiaphy (Member, Author) commented May 24, 2019

Updates:

  1. Only operators det and slogdet are kept in this PR to match numpy.
  2. Zero gradient when det = 0 is now the intended behavior and is explained in the operator docs:

No gradient is propagated backward when A is non-invertible (which is equivalent to det(A) = 0), because zero is rarely hit exactly in floating point computation and Jacobi's formula for the determinant gradient is not computationally efficient when A is non-invertible.
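
A hypothetical usage sketch of this convention once the operator is available (an editor's illustration assuming the merged linalg.det and the standard NDArray autograd API):

import mxnet as mx
from mxnet import autograd

A = mx.nd.array([[1., 2.], [2., 4.]])   # singular: det(A) = 0
A.attach_grad()
with autograd.record():
    d = mx.nd.linalg.det(A)
d.backward()
print(A.grad)   # all zeros, per the convention documented above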

@mseeger (Contributor) commented May 27, 2019

Hello,

I developed the linalg operators with @asmushetzel.

I feel it is not very elegant to introduce a lot of MXNet operators which essentially just stick existing operators together. It would be a lot cleaner to just provide Python functions for this (using F in {mx.nd, mx.sym} as the first arg).

Note we ourselves made this mistake with sumlogdiag, which is ugly and should not be there, really (we could be excused, since diag wasn't there back then).

For example, if you really want logdet, you can get it from an LQ decomposition (linalg.gelqf) followed by log and sum over the diagonal of L; we have ops for all of this. In fact, you probably want log(abs(det)), because the determinant could be negative. You could return the sign as well.

I also don't understand why a det(.) op is needed, given there is logdet(.) with sign. You can get one from the other. Also, det(.) for large matrices is prone to overflow or underflow anyway.

It is also somewhat dangerous to offer such operators, because they end up recomputing the underlying factorizations every time. For example, to evaluate the log likelihood of a Gaussian, you need logdet and backsubstitution. You compute the Cholesky decomp. once, then use it twice. With your logdet, I cannot do that. It computes something inside, but does not return it.

Finally, I also find it dangerous to offer inverse. The matrix inverse is almost never needed, but people who lack numerical maths knowledge use it. They should not; it leads to bad code. They should use matrix factorizations like Cholesky, LQ (i.e. QR), or SVD.

So, if you really want to do something useful, think about a set of Python functions for derived operators. An example:

def linalg_logdet_from_chol(F, lmat):
    return F.sum(F.log(F.abs(F.diag(lmat))))

If you really want:

def linalg_logdet(F, amat):
    return linalg_logdet_from_chol(F, F.potrf(amat))
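
A minimal usage sketch of this idea (an editor's illustration, not part of the comment above): in released mxnet the factorization lives under F.linalg.potrf rather than F.potrf, and since A = L * L^T, log det(A) is twice the log-determinant of the Cholesky factor, so a factor of 2 is added here.

import mxnet as mx

def linalg_logdet_from_chol(F, lmat):
    # log|det(L)| for a triangular Cholesky factor L
    return F.sum(F.log(F.abs(F.diag(lmat))))

def linalg_logdet(F, amat):
    # log det(A) = 2 * log det(L) when A = L * L^T
    return 2 * linalg_logdet_from_chol(F, F.linalg.potrf(amat))

A = mx.nd.array([[4., 1.], [1., 3.]])   # symmetric positive definite, det = 11
print(linalg_logdet(mx.nd, A))          # ~ log(11) ≈ 2.398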

@arcadiaphy (Member, Author) commented May 28, 2019

@mseeger I think many ordinary users are not experts in linear algebra, so providing a set of basic blas/lapack routines and some higher-level operators at the same time is the best way to serve all levels of users. My motivation for improving the linalg package comes from my daily usage and from users' feature requests. I agree that composing basic operators is more efficient, but sometimes ease of use is more important, and it doesn't matter if the computation is not the fastest as long as I can finish the implementation quickly.

Operators like inverse, det and slogdet are implemented in numpy, tensorflow and pytorch; it's really annoying that there are no mxnet equivalents when "translating" implementations from other platforms. (Actually, these related PRs come from my project, which involves translating a thin plate spline algorithm.)

As for using Python to create derived operators, that's actually what I wanted to do in the first place, but it's restricted by mxnet's weak support for registering backward functions. Using fine-grained operators to mimic a high-level operation is fine for the forward pass, but the backward pass becomes terrible because a simple combined gradient computation gets split into the backward passes of the basic operators. Currently the only way to override a backward function in mxnet is a Custom operator, which is not elegant. In pytorch it's more user- and developer-friendly to do this; the following is the backward of matrix inverse registered in a yaml file:

- name: inverse(Tensor self)
  self: -at::matmul(result.transpose(-2, -1), at::matmul(grad, result.transpose(-2, -1)))
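
For contrast, here is roughly what the same customization takes in mxnet today via a Custom operator (an editor's hypothetical sketch, not part of this PR; names like NaiveDet are made up), which illustrates why this route feels heavyweight:

import mxnet as mx
import numpy as np
from mxnet import autograd

class NaiveDet(mx.operator.CustomOp):
    def forward(self, is_train, req, in_data, out_data, aux):
        a = in_data[0].asnumpy()
        self.assign(out_data[0], req[0], mx.nd.array([np.linalg.det(a)]))

    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
        a = in_data[0].asnumpy()
        # Jacobi's formula for invertible a: d det(a)/da = det(a) * inv(a)^T
        grad = out_grad[0].asscalar() * np.linalg.det(a) * np.linalg.inv(a).T
        self.assign(in_grad[0], req[0], mx.nd.array(grad))

@mx.operator.register("naive_det")
class NaiveDetProp(mx.operator.CustomOpProp):
    def __init__(self):
        super(NaiveDetProp, self).__init__(need_top_grad=True)
    def list_arguments(self):
        return ['data']
    def list_outputs(self):
        return ['output']
    def infer_shape(self, in_shape):
        return [in_shape[0]], [(1,)], []
    def create_operator(self, ctx, shapes, dtypes):
        return NaiveDet()

A = mx.nd.array([[1., 4.], [2., 3.]])
A.attach_grad()
with autograd.record():
    d = mx.nd.Custom(A, op_type='naive_det')
d.backward()
print(d, A.grad)   # det = -5, grad = [[3., -2.], [-4., 1.]]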

I think a good linalg package looks like scipy.linalg, where all the blas and lapack routines are exposed and there are also a lot of high-level functions.

@piyushghai (Contributor)

@arcadiaphy What's the path forward on this PR? :)

@arcadiaphy (Member, Author)

@piyushghai There are two paths forward for linalg:

  1. Use low-level operators to implement the determinant in Python, but it's not very easy to override the backward function.

  2. Provide low-level and high-level operators in C++ at the same time.

This PR takes the 2nd option. If it's merged, the third part of the linalg improvement is to add svd and pinverse.

@asmushetzel (Contributor)

I tend to agree with @arcadiaphy on adding these operators, even though conceptually you can do the same more efficiently by explicitly handling the matrix factorizations etc. on your own. Providing a more seamless experience when porting models from one framework to another is important for adoption. No objections from my side to merging this PR.

@arcadiaphy (Member, Author)

@szha How about we merge this PR and move forward?

@karan6181 (Contributor)

@szha Is this PR good to go for merge? Thanks!

@piyushghai (Contributor)

@szha Gentle ping for review.

@reminisce (Contributor)

We would like to merge this PR so that we can push forward with adding more ops to match np.linalg. If there are no objections, I will merge this PR tomorrow. Thanks for all the comments and contributions.

@reminisce reminisce merged commit 45db8ea into apache:master Aug 7, 2019
anirudhacharya pushed a commit to anirudhacharya/mxnet that referenced this pull request Aug 20, 2019
* add backbone

* cpu forward det

* refactor for gpu forward det

* fix

* register gpu det forward

* add gpu det backward

* register gpu det backward

* fix

* add logdet slogdet backward

* stop grad for zero det

* fix

* fix

* reduce grad transfer

* fix docs

* update comments

* fix docs

* fix lint

* add test

* update docs

* add operator

* update test

* trigger CI

* remove slash

* update comments and docs

* update det helper function

* update operator check

* remove logdet

* add no grad when det = 0

* update comments and docs

* remove remaining logdet
@arcadiaphy deleted the pr_det branch on December 5, 2019, 10:36