This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Numpy] Differentiable svd #15795

Merged · 4 commits · Sep 19, 2019

Conversation

@hzfan (Contributor) commented Aug 8, 2019

Description

Differentiable svd

Checklist

Essentials

Please feel free to remove inapplicable items for your PR.

  • The PR title starts with [MXNET-$JIRA_ID], where $JIRA_ID refers to the relevant JIRA issue created (except PRs with tiny changes)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • For new examples, README.md is added to explain what the example does, the source of the dataset, expected performance on the test set, and a reference to the original paper if applicable
  • Check the API doc at http://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-$PR_ID/$BUILD_ID/index.html
  • To the best of my knowledge, examples are either not affected by this change or have been fixed to be compatible with this change

Changes

  • add np.linalg.svd (a usage sketch follows after this list)
  • add forward and backward tests
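Not part of the PR: a minimal sketch of how the new operator might be exercised with autograd, assuming the (ut, l, v) return order and the UT·diag(L)·V reconstruction convention described in the backend comments further down; shapes are illustrative.

    from mxnet import autograd, np, npx
    npx.set_np()

    a = np.random.uniform(size=(4, 6))
    a.attach_grad()
    with autograd.record():
        ut, l, v = np.linalg.svd(a)   # assumed return order: UT, singular values, V
        loss = (l * l).sum()          # any scalar function of the outputs
    loss.backward()
    print(a.grad)                     # gradient flows back through the SVD

    # Assumed reconstruction convention, matching the backend comment (UT, L, V) = gesvd(A):
    print(np.abs(np.matmul(ut * l, v) - a).max())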

Comments

Thanks to @reminisce and @haojin2 for review and guidance.

@hzfan hzfan requested a review from szha as a code owner August 8, 2019 06:50
@hzfan hzfan changed the title Differentiable svd [Numpy] Differentiable svd Aug 8, 2019
@reminisce reminisce added the Numpy label Aug 8, 2019
@@ -242,6 +249,20 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {
#define MXNET_LAPACK_sgetrf LAPACKE_sgetrf
#define MXNET_LAPACK_dgetrf LAPACKE_dgetrf

#define MXNET_LAPACK_CWRAP_GESVD(prefix, dtype) \
inline int MXNET_LAPACK_##prefix##gesvd(int matrix_layout, int m, int n, dtype* ut, \
Contributor:

Not sure I understand this case. Maybe add a comment?

Contributor Author:

The LAPACK_gesvd function interface differs in signature from the MXNET_LAPACK signature and has to be wrapped (as stated here). So this is basically a wrapper of LAPACK_gesvd.

I added some comments about how to use LAPACK_gesvd. Its official documentation can be found here.
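Not part of the PR: for readers who want to exercise the underlying LAPACK routine from Python, SciPy exposes the same gesvd driver; the array shapes here are illustrative.

    import numpy as np
    from scipy.linalg import svd

    a = np.random.rand(3, 5)
    # SciPy defaults to the gesdd driver; select gesvd explicitly, the routine
    # wrapped by MXNET_LAPACK_*gesvd above.
    u, s, vh = svd(a, full_matrices=False, lapack_driver='gesvd')
    assert np.allclose(u @ np.diag(s) @ vh, a)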

@@ -361,6 +382,26 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {
MXNET_LAPACK_CWRAP_SYEVD(ssyevd, float)
MXNET_LAPACK_CWRAP_SYEVD(dsyevd, double)

#define MXNET_LAPACK_CWRAP_GESVD(func, dtype) \
Contributor:

Add a comment that, due to the row-major interface and internal column-major layout, the arguments are flipped and transposed, and m and n are swapped as well.

Contributor Author:

Added
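Not from the PR: a small NumPy illustration of the flip being documented. Computing the SVD of the transposed (column-major-viewed) matrix returns the same singular values, with the roles of the two singular-vector factors swapped, up to per-column signs.

    import numpy as np

    m, n = 3, 5
    a = np.random.rand(m, n)

    u, s, vt = np.linalg.svd(a, full_matrices=False)         # A   = U  diag(s)  Vt
    u2, s2, vt2 = np.linalg.svd(a.T, full_matrices=False)    # A^T = U' diag(s') Vt'

    # Same singular values; the left/right singular-vector factors swap roles
    # (individual columns may differ in sign, hence the abs()).
    assert np.allclose(s, s2)
    assert np.allclose(np.abs(u2), np.abs(vt.T))
    assert np.allclose(np.abs(vt2), np.abs(u.T))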

// CPU/GPU-versions of LAPACK function "gesvd". Please refer to the
// LAPACK documentation for further details.
// Note:
// - V is input and output parameter (overwritten by A)
Contributor:

V is input and output parameter (it overwrites A)

Contributor Author:

Fixed

Stream<cpu> *s) { \
check_gesvd(UT, L, V); \
DType lwork(0); \
MXNET_LAPACK_##fname(MXNET_LAPACK_ROW_MAJOR, V.size(0), V.size(1), \
Contributor:

I think this is wrong. You must have done the workspace query before (calling the other function), so the size of work will be fine to pass for lwork, right? So no need to do the workspace query again here.

Contributor:

The reason I called the workspace query again in the syevd implementation is that there, work consists of two different workspaces. But here, I think you can just pass work.size(0) for lwork and don't have to do the query again.

Contributor Author:

Yes, you are right. I have removed the query.

LINALG_CPU_GESVD(sgesvd, float)
LINALG_CPU_GESVD(dgesvd, double)

// Mangle temp storage requirements for DType and int into a single
Contributor:

Remove this comment; it only applies to syevd. See my comment above: you have a single workspace here.

Contributor Author:

Removed


// (UT, L, V) = gesvd(A) [singular value decomposition]
// - V can overwrite A
// - Needs workspace (both DType and int), size of which is determined by a
Contributor:

Only one workspace (DType) needed here

Contributor Author:

Fixed

});
}

// Helper for gesvd_backward. See technical report for details
Contributor:

Is the report cited somewhere?

Contributor Author:

Not yet. The public technical report (https://arxiv.org/pdf/1710.08717.pdf) does not include details about svd.

Contributor:

Ah, you are right. I will have the new version of the report uploaded.

Contributor:

Could you cite the arxiv paper in the code? The new version with SVD will be uploaded in the next few days, way before this CR will get merged. Thanks.

Contributor Author:

Yes. Cited.

return 1e-100;
}

struct GesvdBackHelper_dV {
Contributor:

Comment: dA overwritten by L^-1 dA

Contributor Author:

Comment added
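Not the kernel itself: a NumPy rendering of the operation the requested comment describes (scale row i of dA by 1/L[i]); the eps floor is illustrative.

    import numpy as np

    def scale_rows_by_inv_l(dA, L, eps=1e-30):
        # dA: (m, n), L: (m,) singular values; row i of dA is scaled by 1/L[i].
        # eps is an illustrative floor guarding against (near-)zero singular values.
        return dA / np.maximum(L, eps)[:, None]

    dA = np.arange(6.0).reshape(2, 3)
    L = np.array([2.0, 4.0])
    print(scale_rows_by_inv_l(dA, L))   # row 0 divided by 2, row 1 divided by 4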

}
};

struct GesvdBackHelper_G1 {
Contributor:

Comment: X (square) overwritten by L X

Contributor Author:

imho, X (square) overwritten by X L

Contributor:

Yes, it is all transposed because of the row/column-major issue.

Contributor Author:

Comment added
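Not from the PR: a quick NumPy check of the point above. Left-multiplication by the diagonal in the column-major derivation corresponds to right-multiplication (column scaling) in the row-major code, because everything is transposed.

    import numpy as np

    m = 3
    X = np.random.rand(m, m)
    L = np.random.rand(m)

    # Column-major derivation: L X (row scaling).  Row-major code: X L (column scaling).
    # The two are transposes of each other.
    assert np.allclose((np.diag(L) @ X.T).T, X @ np.diag(L))
    assert np.allclose(X @ np.diag(L), X * L)   # right-multiplying by diag(L) scales columns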


// G1:
// This copy is just to make sure there are no invalid values (NaN, infinity) in tempM
Copy(tempMs, dUT, s);
Contributor:

Good! I hope our old code does that as well

Contributor Author:

Our old code does that. I learned about this from our old code (syevd).

// G1:
// This copy is just to make sure there are no invalid values (NaN, infinity) in tempM
Copy(tempMs, dUT, s);
Copy(tempMr, dA, s);
@mseeger (Contributor) commented Aug 13, 2019

Damn, you are right, we need this temp space, because we cannot left-multiply with UT in place.

This could be a real problem for big matrices. It can be circumvented by implementing the left-multiplication with a small square matrix in-place. This is extra work, but it would be needed to avoid the large extra temp space.

}
for (int i = 0; i < m; ++i) {
elem = DType(0.0);
for (int j = 0; j < n; ++j) {
Contributor:

This is not needed. You have already computed this using gemm; just pass in the diagonal.
Also, the function does not need dA and V as input args.

Contributor Author:

Yes, fixed.

// This copy is just to make sure there are no invalid values (NaN, infinity) in tempM
Copy(tempMs, dUT, s);
Copy(tempMr, dA, s);
gemm::op(dA, V, tempMs, DType(1.0), DType(0.0), false, true, s);
Contributor:

After this line, you extract the diagonal of this matrix and then pass it to GesvdBackHelper_G2

Contributor Author:

Note that tempMs cannot be used to store the extracted diagonal (I have used tempMs to store G1). Do we need extra temp space (of size m) to store the diagonal?

Contributor:

Yes, true, but that is really a small amount of extra space. It does have to be allocated as well, though. In return, we can get rid of tempMr (below).
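A NumPy sketch (not the PR's mshadow code) of the reuse being discussed: the diagonal needed later is already present in the product computed for G1, so only an (m,)-sized buffer is required instead of recomputing it elementwise.

    import numpy as np

    m, n = 3, 5
    dA = np.random.rand(m, n)   # stands in for the gradient buffer at this point
    V = np.random.rand(m, n)

    G1 = dA @ V.T                     # (m, m) product already computed via gemm
    diag = np.diagonal(G1).copy()     # reuse its diagonal: extra space is only (m,)

    # Identical to recomputing the diagonal elementwise, which is no longer needed:
    assert np.allclose(diag, np.einsum('ij,ij->i', dA, V))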

if (dA.dptr_ != dV.dptr_) {
Copy(dA, dV, s);
}
// From here on, we work on dA only
Contributor:

Could you assign k = dA.size(0), m = dA.size(1), n = dA.size(2) here, and use them below?

Contributor Author:

Fixed

gemm::op(dUT, UT, tempMs, DType(1.0), DType(1.0), true, false, s);

// G2:
Kernel<GesvdBackHelper_G2, xpu>::Launch
Contributor:

Pass in the diagonal extracted above, and do not pass dA, V.

Contributor Author:

Fixed


// G3:
gemm::op(tempMs, V, dA, DType(1.0), DType(1.0), false, false, s);
gemm::op(UT, dA, tempMr, DType(1.0), DType(0.0), false, false, s);
Contributor:

It is very annoying that we need this large temp space, because we don't have an in-place left-multiply with a square matrix. The drawback is that for large matrices (large n), this needs a lot of temporary memory; it may be worth avoiding that.

const Tensor<xpu, 3, DType>& V,
const Tensor<xpu, 3, DType>& dA,
const Tensor<xpu, 3, DType>& tempMs,
const Tensor<xpu, 3, DType>& tempMr,
Contributor:

As mentioned below, this large temp space could be avoided with some extra implementation work.

@hzfan (Contributor Author) commented Aug 13, 2019

Could you elaborate on how to code up "in-place" dA <- dot(UT, dA) using a temp space of shape (m, m)?
The method I have thought about is:
Write dA as blocks, dA = [dA1, dA2, ..., dAx], where each dAi is of shape (m, m) and x = ceil(n / m),
so dot(UT, dA) = [dot(UT, dA1), ..., dot(UT, dAx)], and each dot can be computed with a temp space of shape (m, m).
I don't know whether I have understood correctly.

Contributor:

Yes, exactly; something like what you are saying. Essentially you slice the large matrix up into (m, m) blocks, and then you can even just call gemm in a loop. This may in fact be the easiest. What I mean is: you mask out (m, m) blocks of dA and iterate over them, always replacing one block B with dot(UT, B). For that, you need an (m, m) temp space, which you already have.

Contributor:

Please ask if this is still unclear (but your comment is what I have in mind). You can mask out blocks of dA simply by moving the pointer; everything else (stride, etc.) remains the same, because the n-axis is the contiguous one. The final block will be (m, m2), where m2 <= m, but that should not be a problem either.

@hzfan (Contributor Author) commented Aug 14, 2019

It's clear now. See my comment below to check whether my implementation is correct.

@@ -0,0 +1,131 @@
# Licensed to the Apache Software Foundation (ASF) under one
Contributor:

Very nice test!

@mseeger (Contributor) left a comment

The main point to change is extracting the diag in backward, instead of recomputing it.

Another point could be to avoid the large temp space by coding up "in-place"
dA <- dot(UT, dA)
using a temp space of shape (m, m). You can reuse the one you already have, since at this point it is not needed anymore. This could be worth it when the operator is used on large matrices.

DType gesvd_back_helper_eps(DType* X);

template<>
MSHADOW_XINLINE float gesvd_back_helper_eps(float* X) {
Contributor:

It would be cleaner to make these constants depend on values from std::numeric_limits, so something along the lines of
std::numeric_limits<DType>::epsilon() * 10 (or whatever factor you need to be on the safe side).
That also avoids the template specialization here.

Contributor Author:

I tried it out, but I found that std::numeric_limits<DType>::epsilon() cannot be accessed in CUDA. So I will stick with the original implementation for now.
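For reference only (the PR keeps the hard-coded per-dtype constants, since numeric_limits is not usable from the CUDA code here): the reviewer's suggestion expressed with NumPy's dtype metadata, with the factor of 10 as the illustrative safety margin mentioned above.

    import numpy as np

    # Analogue of std::numeric_limits<DType>::epsilon() * 10 for each dtype.
    for dtype in (np.float32, np.float64):
        print(dtype.__name__, np.finfo(dtype).eps * 10)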


// G1:
// This copy is just to make sure there are no invalid values (NaN, infinity) in tempM
Copy(tempMs, dUT, s);
Contributor:

It may be cleaner to fill with zeros instead of copying some arbitrary data. I think there is a specific method in mshadow or elsewhere to fill a tensor with such values (if not, copying like this is likely the best way).

Contributor Author:

The only way I know to fill a tensor with zeros is:

    tempMs.FlatTo1D() = 0;
    tempMr.FlatTo1D() = 0;

But our old code (syevd) sticks to copying. Which do you think is better?

Contributor Author:

The copy in syevd is here

Contributor:

Aha, so it seems I am to blame for that. But it is fine

Contributor Author:

I changed the copying to filling with zeros.

@hzfan hzfan force-pushed the svd_pr branch 2 times, most recently from da6d472 to 67191c4 on August 14, 2019 09:02
@hzfan (Contributor Author) commented Aug 14, 2019

Just updated the code.
The two main points updated:

  • reuse the diagonal of (L^-1 dV VT) with an additional temp space of shape (m,)
  • avoid tempMr (which was of shape (m, n)) by splitting the large dot into per-block dots (see the sketch below).

Now the total temp space used is (k, m, m) + (k, m).
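A NumPy sketch (not the PR's mshadow/gemm code) of the second point: dA <- dot(UT, dA) computed block by block so that only a small temporary is ever needed; the helper name is invented for illustration.

    import numpy as np

    def left_mul_in_blocks(UT, dA):
        """Overwrite dA with dot(UT, dA), looping over (m, m) column blocks.
        The last block may be narrower (the m2 <= m case discussed above)."""
        m, n = dA.shape
        for i in range(0, n, m):
            block = dA[:, i:i + m]     # a view into dA, shape (m, ncols) with ncols <= m
            block[:] = UT @ block      # only an (m, ncols)-sized temporary is formed
        return dA

    m, n = 4, 10
    UT = np.random.rand(m, m)
    dA = np.random.rand(m, n)
    expected = UT @ dA
    left_mul_in_blocks(UT, dA)
    assert np.allclose(dA, expected)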


// G3:
gemm::op(tempM, V, dA, DType(1.0), DType(1.0), false, false, s);
for (int i = 0; i < n; i += m) {
@hzfan (Contributor Author) commented Aug 14, 2019

Here I iterate over the blocks dAi = dA(0:m, i:i+m). Other tensor attributes like stride_ and stream_ are not changed, as you said. ncols is the m2 you mentioned; it is used for the last block.


Contributor:

Add comment:
dA <- dot(UT, dA). Loop over (k, m, m) blocks to avoid large temporary memory

Contributor Author:

Comment added.

@mseeger (Contributor) left a comment

Very nice. Just one more note about a missing comment, and then this is ready to go as far as I am concerned. Just make sure your test works both on CPU and GPU.

@hzfan (Contributor Author) commented Aug 15, 2019

Thanks to @mseeger and @asmushetzel for guidance and review.

@reminisce reminisce merged commit 6247dc8 into apache:master Sep 19, 2019
drivanov pushed a commit to drivanov/incubator-mxnet that referenced this pull request Sep 26, 2019
* use (m, m) temp space

* add technical report citation

* add comments for the tricky block matrix multiplication

* differentiable svd
larroy pushed a commit to larroy/mxnet that referenced this pull request Sep 28, 2019
* use (m, m) temp space

* add technical report citation

* add comments for the tricky block matrix multiplication

* differentiable svd