Conversation
@@ -242,6 +249,20 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {
#define MXNET_LAPACK_sgetrf LAPACKE_sgetrf
#define MXNET_LAPACK_dgetrf LAPACKE_dgetrf

#define MXNET_LAPACK_CWRAP_GESVD(prefix, dtype) \
inline int MXNET_LAPACK_##prefix##gesvd(int matrix_layout, int m, int n, dtype* ut, \
Not sure I understand this case. Maybe add a comment?
@@ -361,6 +382,26 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {
MXNET_LAPACK_CWRAP_SYEVD(ssyevd, float)
MXNET_LAPACK_CWRAP_SYEVD(dsyevd, double)

#define MXNET_LAPACK_CWRAP_GESVD(func, dtype) \
Add a comment that, due to the row-major input and the internal column-major convention, the arguments are flipped and transposed, and m and n are flipped as well.
Added
src/operator/linalg.h
Outdated
// CPU/GPU-versions of LAPACK function "gesvd". Please refer to the
// LAPACK documentation for further details.
// Note:
// - V is input and output parameter (overwritten by A) |
V is input and output parameter (it overwrites A)
Fixed
src/operator/linalg_impl.h
Outdated
Stream<cpu> *s) { \
check_gesvd(UT, L, V); \
DType lwork(0); \
MXNET_LAPACK_##fname(MXNET_LAPACK_ROW_MAJOR, V.size(0), V.size(1), \ |
I think this is wrong. You must have done the workspace query before (calling the other function), so the size of work will be fine to pass for lwork, right? So there is no need to do the workspace query again here.
The reason why I called the workspace query again in the syevd implementation is that there, work consists of two different workspaces. But here, I think you can just pass work.size(0) for lwork and don't have to do the query again.
Yes, you are right. I have removed the query.
src/operator/linalg_impl.h
Outdated
LINALG_CPU_GESVD(sgesvd, float)
LINALG_CPU_GESVD(dgesvd, double)

// Mangle temp storage requirements for DType and int into a single |
Remove this comment, it only applies to syevd. See my comment above, you have a single workspace here.
Removed
// (UT, L, V) = gesvd(A) [singular value decomposition]
// - V can overwrite A
// - Needs workspace (both DType and int), size of which is determined by a |
Only one workspace (DType) is needed here.
Fixed
});
}

// Helper for gesvd_backward. See technical report for details |
Is the report cited somewhere?
Not yet. The public technical report (https://arxiv.org/pdf/1710.08717.pdf) does not include details about SVD.
Ah, you are right. I will have the new report version uploaded.
Could you cite the arxiv paper in the code? The new version with SVD will be uploaded in the next few days, well before this CR gets merged. Thanks.
Yes. Cited.
return 1e-100;
}

struct GesvdBackHelper_dV { |
Comment: dA overwritten by L^-1 dA
Comment added
}
};

struct GesvdBackHelper_G1 { |
Comment: X (square) overwritten by L X
imho, X (square) overwritten by X L
Yes, it is all transposed because of the row-/column-major issue.
Comment added
// G1:
// This copy is just to make sure there are no invalid values (NaN, infinity) in tempM
Copy(tempMs, dUT, s);
Good! I hope our old code does that as well
Our old code does that; I learnt this from the syevd implementation.
// G1:
// This copy is just to make sure there are no invalid values (NaN, infinity) in tempM
Copy(tempMs, dUT, s);
Copy(tempMr, dA, s);
Damn, you are right, we need this temp space because we cannot left-multiply with UT in place.
This could be a real problem for big matrices. It can be circumvented by implementing the left-multiplication with a small square matrix in place. This is extra work, but it would be needed to avoid the large extra temp space.
}
for (int i = 0; i < m; ++i) {
elem = DType(0.0);
for (int j = 0; j < n; ++j) {
This is not needed. You have computed this before already, using gemm; just pass in the diagonal.
Also, the function does not need dA and V as input args.
Yes, fixed.
// This copy is just to make sure there are no invalid values (NaN, infinity) in tempM
Copy(tempMs, dUT, s);
Copy(tempMr, dA, s);
gemm::op(dA, V, tempMs, DType(1.0), DType(0.0), false, true, s);
After this line, you extract the diagonal of this matrix and then pass it to GesvdBackHelper_G2
Note that tempMs cannot be used to store the extracted diagonal (I have used tempMs to store G1). Do we need extra temp space (of size m) to store the diagonal?
Yes, true, but that is really small extra space. It does have to be allocated as well. In exchange, we can get rid of tempMr (below).
if (dA.dptr_ != dV.dptr_) {
Copy(dA, dV, s);
}
// From here on, we work on dA only
Could you assign k = dA.size(0), m = dA.size(1), n = dA.size(2) here, and use them below?
Fixed
gemm::op(dUT, UT, tempMs, DType(1.0), DType(1.0), true, false, s);

// G2:
Kernel<GesvdBackHelper_G2, xpu>::Launch
Pass in the diagonal extracted above, and do not pass dA and V.
Fixed

// G3:
gemm::op(tempMs, V, dA, DType(1.0), DType(1.0), false, false, s);
gemm::op(UT, dA, tempMr, DType(1.0), DType(0.0), false, false, s);
It is very annoying that we need this large temp space, because we don't have an in-place left-multiply with a square matrix. The drawback here is that for large matrices (large n), this needs a large temp space. It may be worth avoiding that.
const Tensor<xpu, 3, DType>& V,
const Tensor<xpu, 3, DType>& dA,
const Tensor<xpu, 3, DType>& tempMs,
const Tensor<xpu, 3, DType>& tempMr,
As mentioned below, this large temp space could be avoided by some extra implementation
Could you elaborate on how to
"code up 'in-place' dA <- dot(UT, dA) using a temp space of shape (m, m)"?
The method I have thought about is:
we can write dA as blocks, dA = [dA1, dA2, ..., dAx], where each dAi is of shape (m, m) and x = ceil(n / m),
so dot(UT, dA) = [dot(UT, dA1), ..., dot(UT, dAx)], and each dot can be computed with temp space of shape (m, m).
I don't know whether I have understood correctly.
Yes, exactly, something like you are saying. Essentially you slice the large matrix up into (m, m) blocks, and then you can even just call gemm in a loop. This may be the easiest, in fact. What I mean is, you mask out (m, m) blocks of dA and iterate over them, always replacing one block B with dot(UT, B). For that, you need an (m, m) temp space, but that you have.
Please ask if this is still unclear (but your comment is what I have in mind). You can mask out blocks of dA simply by moving the pointer, everything else (stride, etc) remains the same, because the n-axis is the continuous one. The final block will be (m, m2), where m2 <= m, but that should also not be a problem.
It's clear now. See my comment below to check whether my implementation is correct.
@@ -0,0 +1,131 @@
# Licensed to the Apache Software Foundation (ASF) under one |
Very nice test!
The main point to change is extracting the diag in backward, instead of recomputing it.
Another point could be to avoid the large temp space by coding up "in-place"
dA <- dot(UT, dA)
using a temp space of shape (m, m). You can use the one you already have; at this point it is not needed anymore. This could be worth it when this is used for large matrices.
DType gesvd_back_helper_eps(DType* X);

template<>
MSHADOW_XINLINE float gesvd_back_helper_eps(float* X) {
It would be cleaner to make these constants dependent on values in std::numeric_limits. So something along the lines of
std::numeric_limits<DType>::epsilon() * 10 (or whatever factor you need to be on the safe side).
That also avoids doing template specialization here.
I tried it out, but I found that std::numeric_limits<DType>::epsilon() cannot be accessed in CUDA device code. So I stick with the original implementation for now.

// G1:
// This copy is just to make sure there are no invalid values (NaN, infinity) in tempM
Copy(tempMs, dUT, s);
It may be cleaner to fill with zeros instead of copying some arbitrary data. I think there is a specific method in mshadow or elsewhere to fill a tensor with such values (if not, copying like here is likely the best way).
The only way I know to fill a tensor with zeros is:
tempMs.FlatTo1D() = 0;
tempMr.FlatTo1D() = 0;
But our old code (syevd) sticks to copying. Which do you think is better?
The copy in syevd is here
Aha, so it seems I am to blame for that. But it is fine
I changed the copying to filling with zeros.
Just updated the code.
Now the total temp space used is (k, m, m) + (k, m)

// G3:
gemm::op(tempM, V, dA, DType(1.0), DType(1.0), false, false, s);
for (int i = 0; i < n; i += m) {
Here I go through the blocks dAi = dA(0:m, i:i+m). Other attributes of the tensors, like stride_ and stream_, are not changed, as you said. The ncols is the m2 you just mentioned; it is used for the last block.
Add comment:
dA <- dot(UT, dA). Loop over (k, m, m) blocks to avoid large temporary memory
Comment added.
Very nice. Just one more comment about a missing comment, then this is ready to go, as far as I am concerned. Just make sure your test works both on CPU and GPU.
Thanks to @mseeger and @asmushetzel for guidance and review.
* use (m, m) temp space
* add technical report citation
* add comments for the tricky block matrix multiplication
* differentiable svd
Description
Differentiable svd
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.
Changes
Comments
Thanks to @reminisce and @haojin2 for review and guidance.