[src] Speed up AddDiagMat2 for very thin and tall matrices#2555
[src] Speed up AddDiagMat2 for very thin and tall matrices#2555danpovey merged 1 commit intokaldi-asr:masterfrom
Conversation
d0eddbf to
3566a89
Compare
|
And FYI, this is the error: |
src/cudamatrix/cu-vector.cc
Outdated
| MatrixTransposeType other_trans = (trans == kTrans ? kNoTrans : kTrans); | ||
| this->AddDiagMatMat(alpha, M, trans, | ||
| M, other_trans, beta); | ||
| if (trans == kTrans && M.NumCols() <= 512 || M.NumCols() <= 128) { |
There was a problem hiding this comment.
i'm not sure that the part after the '||' here makes sense.
There was a problem hiding this comment.
|| is to speed up other cases. Removing that part is OK and won't affect your special case.
| int32 dim, MatrixTransposeType trans) { | ||
| BaseFloat time_in_secs = 0.02; | ||
| int32 size = 1024 * 32; | ||
| CuVector<Real> v(std::max(dim, size / dim)); |
There was a problem hiding this comment.
The std::max here probably isn't right, you should be deciding which one to use based on 'trans'.
I think this is why it was failing on the CPU but not GPU version: there was likely a dimension check in the CPU version that is not present in the GPU version (but should be present).
There was a problem hiding this comment.
I saw your fix. Your way is correct.
|
My point is that transposing whenever M.NumCols() <= 128, without asking
what the num-rows is, doesn't make sense. For instance, it would transpose
a 64x32 matrix and also would transpose a 32x64 matrix, which can't be
right because one must be faster than the other.
…On Tue, Jul 17, 2018 at 6:16 PM, Shiyin Kang ***@***.***> wrote:
***@***.**** commented on this pull request.
------------------------------
In src/cudamatrix/cu-vector.cc
<#2555 (comment)>:
> @@ -569,8 +569,12 @@ void CuVectorBase<Real>::AddDiagMat2(Real alpha, const CuMatrixBase<Real> &M,
if (CuDevice::Instantiate().Enabled()) {
if (dim_ == 0) return;
MatrixTransposeType other_trans = (trans == kTrans ? kNoTrans : kTrans);
- this->AddDiagMatMat(alpha, M, trans,
- M, other_trans, beta);
+ if (trans == kTrans && M.NumCols() <= 512 || M.NumCols() <= 128) {
|| is to speed up other cases. Removing that part is OK and won't affect
your special case.
—
You are receiving this because you commented.
Reply to this email directly, view it on GitHub
<#2555 (comment)>, or mute
the thread
<https://github.com/notifications/unsubscribe-auth/ADJVu7YLLAPqOtJqVQSxUO2wNQ-uVl6jks5uHox_gaJpZM4VSx9M>
.
|
Speed(gflops) size old new speedup
CuVector::AddDiagMat2Shapes<double>[trans], (1048576, 32), 1.10 8.63 7.84x
CuVector::AddDiagMat2Shapes<double>[trans], (524288, 64), 2.19 8.64 3.94x
CuVector::AddDiagMat2Shapes<double>[trans], (262144, 128), 4.38 8.71 1.99x
CuVector::AddDiagMat2Shapes<double>[trans], (131072, 256), 8.64 8.56 0.99x
CuVector::AddDiagMat2Shapes<double>[trans], (65536, 512), 15.71 15.72 1.00x
CuVector::AddDiagMat2Shapes<double>[trans], (32768, 1024), 26.11 26.11 1.00x
CuVector::AddDiagMat2Shapes<double>[trans], (16384, 2048), 31.51 31.51 1.00x
CuVector::AddDiagMat2Shapes<double>[trans], (8192, 4096), 28.08 28.19 1.00x
CuVector::AddDiagMat2Shapes<double>[trans], (4096, 8192), 31.53 31.58 1.00x
CuVector::AddDiagMat2Shapes<double>[trans], (2048, 16384), 31.19 31.23 1.00x
CuVector::AddDiagMat2Shapes<double>[trans], (1024, 32768), 31.42 31.35 1.00x
CuVector::AddDiagMat2Shapes<double>[trans], (512, 65536), 31.47 31.55 1.00x
CuVector::AddDiagMat2Shapes<double>[trans], (256, 131072), 31.00 30.89 1.00x
CuVector::AddDiagMat2Shapes<double>[trans], (128, 262144), 30.02 30.12 1.00x
CuVector::AddDiagMat2Shapes<double>[trans], (64, 524288), 28.44 28.72 1.01x
CuVector::AddDiagMat2Shapes<double>[trans], (32, 1048576), 24.95 24.95 1.00x
CuVector::AddDiagMat2Shapes<float>[trans], (1048576, 32), 1.25 16.44 13.17x
CuVector::AddDiagMat2Shapes<float>[trans], (524288, 64), 2.48 17.19 6.92x
CuVector::AddDiagMat2Shapes<float>[trans], (262144, 128), 4.92 17.08 3.47x
CuVector::AddDiagMat2Shapes<float>[trans], (131072, 256), 9.54 18.33 1.92x
CuVector::AddDiagMat2Shapes<float>[trans], (65536, 512), 17.83 17.92 1.01x
CuVector::AddDiagMat2Shapes<float>[trans], (32768, 1024), 31.46 31.50 1.00x
CuVector::AddDiagMat2Shapes<float>[trans], (16384, 2048), 34.40 34.41 1.00x
CuVector::AddDiagMat2Shapes<float>[trans], (8192, 4096), 51.63 51.70 1.00x
CuVector::AddDiagMat2Shapes<float>[trans], (4096, 8192), 48.73 48.72 1.00x
CuVector::AddDiagMat2Shapes<float>[trans], (2048, 16384), 57.52 57.52 1.00x
CuVector::AddDiagMat2Shapes<float>[trans], (1024, 32768), 56.36 56.35 1.00x
CuVector::AddDiagMat2Shapes<float>[trans], (512, 65536), 55.85 55.87 1.00x
CuVector::AddDiagMat2Shapes<float>[trans], (256, 131072), 55.38 55.71 1.01x
CuVector::AddDiagMat2Shapes<float>[trans], (128, 262144), 54.36 54.61 1.00x
CuVector::AddDiagMat2Shapes<float>[trans], (64, 524288), 52.81 52.99 1.00x
CuVector::AddDiagMat2Shapes<float>[trans], (32, 1048576), 47.58 48.17 1.01x
Speed(gflops) dim old new speedup
CuVector::AddDiagMat2<double>[no-trans], 32 0.08 0.08 1.00x
CuVector::AddDiagMat2<double>[no-trans], 64 0.32 0.32 1.00x
CuVector::AddDiagMat2<double>[no-trans], 128 1.20 1.22 1.01x
CuVector::AddDiagMat2<double>[no-trans], 256 3.85 3.84 1.00x
CuVector::AddDiagMat2<double>[no-trans], 512 11.38 11.37 1.00x
CuVector::AddDiagMat2<double>[no-trans], 1024 22.07 21.93 0.99x
CuVector::AddDiagMat2<double>[no-trans], 2048 29.53 29.55 1.00x
CuVector::AddDiagMat2<double>[no-trans], 4096 31.20 31.74 1.02x
CuVector::AddDiagMat2<double>[no-trans], 8192 32.78 32.80 1.00x
CuVector::AddDiagMat2<double>[trans], 32 0.08 0.08 1.02x
CuVector::AddDiagMat2<double>[trans], 64 0.31 0.31 1.00x
CuVector::AddDiagMat2<double>[trans], 128 1.13 1.15 1.01x
CuVector::AddDiagMat2<double>[trans], 256 3.80 3.84 1.01x
CuVector::AddDiagMat2<double>[trans], 512 11.19 11.20 1.00x
CuVector::AddDiagMat2<double>[trans], 1024 19.49 19.50 1.00x
CuVector::AddDiagMat2<double>[trans], 2048 29.03 29.04 1.00x
CuVector::AddDiagMat2<double>[trans], 4096 27.89 28.01 1.00x
CuVector::AddDiagMat2<double>[trans], 8192 31.72 31.78 1.00x
CuVector::AddDiagMat2<float>[no-trans], 32 0.09 0.09 1.02x
CuVector::AddDiagMat2<float>[no-trans], 64 0.34 0.34 0.99x
CuVector::AddDiagMat2<float>[no-trans], 128 1.30 1.31 1.01x
CuVector::AddDiagMat2<float>[no-trans], 256 4.42 4.45 1.01x
CuVector::AddDiagMat2<float>[no-trans], 512 14.53 14.68 1.01x
CuVector::AddDiagMat2<float>[no-trans], 1024 33.19 33.38 1.01x
CuVector::AddDiagMat2<float>[no-trans], 2048 53.54 53.70 1.00x
CuVector::AddDiagMat2<float>[no-trans], 4096 60.99 61.07 1.00x
CuVector::AddDiagMat2<float>[no-trans], 8192 64.56 64.60 1.00x
CuVector::AddDiagMat2<float>[trans], 32 0.09 0.09 1.02x
CuVector::AddDiagMat2<float>[trans], 64 0.33 0.33 1.00x
CuVector::AddDiagMat2<float>[trans], 128 1.20 1.21 1.01x
CuVector::AddDiagMat2<float>[trans], 256 4.05 4.17 1.03x
CuVector::AddDiagMat2<float>[trans], 512 12.63 12.68 1.00x
CuVector::AddDiagMat2<float>[trans], 1024 23.79 23.90 1.00x
CuVector::AddDiagMat2<float>[trans], 2048 31.41 31.54 1.00x
CuVector::AddDiagMat2<float>[trans], 4096 49.98 50.08 1.00x
CuVector::AddDiagMat2<float>[trans], 8192 49.10 49.35 1.00x
CuVector::AddDiagMatMat<double>[no-trans],[no-trans], 32 0.08 0.08 0.99x
CuVector::AddDiagMatMat<double>[no-trans],[no-trans], 64 0.28 0.28 1.00x
CuVector::AddDiagMatMat<double>[no-trans],[no-trans], 128 0.91 0.90 1.00x
CuVector::AddDiagMatMat<double>[no-trans],[no-trans], 256 2.67 2.67 1.00x
CuVector::AddDiagMatMat<double>[no-trans],[no-trans], 512 5.06 5.06 1.00x
CuVector::AddDiagMatMat<double>[no-trans],[no-trans], 1024 10.73 10.74 1.00x
CuVector::AddDiagMatMat<double>[no-trans],[no-trans], 2048 14.66 14.63 1.00x
CuVector::AddDiagMatMat<double>[no-trans],[no-trans], 4096 14.20 14.20 1.00x
CuVector::AddDiagMatMat<double>[no-trans],[no-trans], 8192 15.75 15.77 1.00x
CuVector::AddDiagMatMat<double>[no-trans],[trans], 32 0.08 0.08 1.00x
CuVector::AddDiagMatMat<double>[no-trans],[trans], 64 0.32 0.32 1.02x
CuVector::AddDiagMatMat<double>[no-trans],[trans], 128 1.21 1.21 1.00x
CuVector::AddDiagMatMat<double>[no-trans],[trans], 256 3.84 3.86 1.01x
CuVector::AddDiagMatMat<double>[no-trans],[trans], 512 8.59 8.60 1.00x
CuVector::AddDiagMatMat<double>[no-trans],[trans], 1024 13.45 13.45 1.00x
CuVector::AddDiagMatMat<double>[no-trans],[trans], 2048 15.64 15.65 1.00x
CuVector::AddDiagMatMat<double>[no-trans],[trans], 4096 16.29 16.29 1.00x
CuVector::AddDiagMatMat<double>[no-trans],[trans], 8192 16.46 16.47 1.00x
CuVector::AddDiagMatMat<double>[trans],[no-trans], 32 0.08 0.08 1.02x
CuVector::AddDiagMatMat<double>[trans],[no-trans], 64 0.31 0.32 1.01x
CuVector::AddDiagMatMat<double>[trans],[no-trans], 128 1.13 1.15 1.02x
CuVector::AddDiagMatMat<double>[trans],[no-trans], 256 3.85 3.83 1.00x
CuVector::AddDiagMatMat<double>[trans],[no-trans], 512 8.13 8.13 1.00x
CuVector::AddDiagMatMat<double>[trans],[no-trans], 1024 13.27 13.30 1.00x
CuVector::AddDiagMatMat<double>[trans],[no-trans], 2048 15.57 15.59 1.00x
CuVector::AddDiagMatMat<double>[trans],[no-trans], 4096 16.04 16.06 1.00x
CuVector::AddDiagMatMat<double>[trans],[no-trans], 8192 16.41 16.41 1.00x
CuVector::AddDiagMatMat<double>[trans],[trans], 32 0.08 0.08 1.00x
CuVector::AddDiagMatMat<double>[trans],[trans], 64 0.28 0.27 0.99x
CuVector::AddDiagMatMat<double>[trans],[trans], 128 0.91 0.90 0.99x
CuVector::AddDiagMatMat<double>[trans],[trans], 256 2.67 2.67 1.00x
CuVector::AddDiagMatMat<double>[trans],[trans], 512 5.02 5.05 1.01x
CuVector::AddDiagMatMat<double>[trans],[trans], 1024 10.73 10.73 1.00x
CuVector::AddDiagMatMat<double>[trans],[trans], 2048 14.73 14.75 1.00x
CuVector::AddDiagMatMat<double>[trans],[trans], 4096 14.23 14.24 1.00x
CuVector::AddDiagMatMat<double>[trans],[trans], 8192 15.73 15.76 1.00x
CuVector::AddDiagMatMat<float>[no-trans],[no-trans], 32 0.08 0.08 0.97x
CuVector::AddDiagMatMat<float>[no-trans],[no-trans], 64 0.29 0.29 1.00x
CuVector::AddDiagMatMat<float>[no-trans],[no-trans], 128 0.97 0.97 1.00x
CuVector::AddDiagMatMat<float>[no-trans],[no-trans], 256 2.94 2.93 1.00x
CuVector::AddDiagMatMat<float>[no-trans],[no-trans], 512 7.72 7.75 1.00x
CuVector::AddDiagMatMat<float>[no-trans],[no-trans], 1024 12.70 12.74 1.00x
CuVector::AddDiagMatMat<float>[no-trans],[no-trans], 2048 28.51 28.56 1.00x
CuVector::AddDiagMatMat<float>[no-trans],[no-trans], 4096 31.13 31.15 1.00x
CuVector::AddDiagMatMat<float>[no-trans],[no-trans], 8192 31.42 31.45 1.00x
CuVector::AddDiagMatMat<float>[no-trans],[trans], 32 0.08 0.08 1.00x
CuVector::AddDiagMatMat<float>[no-trans],[trans], 64 0.34 0.34 0.99x
CuVector::AddDiagMatMat<float>[no-trans],[trans], 128 1.29 1.30 1.01x
CuVector::AddDiagMatMat<float>[no-trans],[trans], 256 4.43 4.45 1.01x
CuVector::AddDiagMatMat<float>[no-trans],[trans], 512 14.60 14.72 1.01x
CuVector::AddDiagMatMat<float>[no-trans],[trans], 1024 23.19 23.26 1.00x
CuVector::AddDiagMatMat<float>[no-trans],[trans], 2048 29.96 29.99 1.00x
CuVector::AddDiagMatMat<float>[no-trans],[trans], 4096 32.02 32.13 1.00x
CuVector::AddDiagMatMat<float>[no-trans],[trans], 8192 32.80 32.81 1.00x
CuVector::AddDiagMatMat<float>[trans],[no-trans], 32 0.08 0.08 0.93x
CuVector::AddDiagMatMat<float>[trans],[no-trans], 64 0.33 0.33 1.00x
CuVector::AddDiagMatMat<float>[trans],[no-trans], 128 1.20 1.22 1.02x
CuVector::AddDiagMatMat<float>[trans],[no-trans], 256 4.14 4.17 1.01x
CuVector::AddDiagMatMat<float>[trans],[no-trans], 512 12.63 12.72 1.01x
CuVector::AddDiagMatMat<float>[trans],[no-trans], 1024 19.47 19.50 1.00x
CuVector::AddDiagMatMat<float>[trans],[no-trans], 2048 25.32 25.34 1.00x
CuVector::AddDiagMatMat<float>[trans],[no-trans], 4096 30.86 30.86 1.00x
CuVector::AddDiagMatMat<float>[trans],[no-trans], 8192 30.75 30.76 1.00x
CuVector::AddDiagMatMat<float>[trans],[trans], 32 0.08 0.08 1.02x
CuVector::AddDiagMatMat<float>[trans],[trans], 64 0.29 0.29 1.01x
CuVector::AddDiagMatMat<float>[trans],[trans], 128 0.97 0.97 1.01x
CuVector::AddDiagMatMat<float>[trans],[trans], 256 2.91 2.93 1.01x
CuVector::AddDiagMatMat<float>[trans],[trans], 512 7.71 7.78 1.01x
CuVector::AddDiagMatMat<float>[trans],[trans], 1024 12.71 12.72 1.00x
CuVector::AddDiagMatMat<float>[trans],[trans], 2048 28.25 28.27 1.00x
CuVector::AddDiagMatMat<float>[trans],[trans], 4096 31.23 31.24 1.00x
CuVector::AddDiagMatMat<float>[trans],[trans], 8192 31.48 31.53 1.00x
3566a89 to
72711ff
Compare
|
I've merged your fix and changed the condition so that only large thin and tall matrix will be transposed. But it seems this test |
| MatrixTransposeType other_trans = (trans == kTrans ? kNoTrans : kTrans); | ||
| this->AddDiagMatMat(alpha, M, trans, | ||
| M, other_trans, beta); | ||
| KALDI_ASSERT(dim_ == (trans == kNoTrans ? M.NumRows() : M.NumCols())); |
There was a problem hiding this comment.
thanks. Would you mind adding the check, in this branch, that the dim is correct? Like in my PR?
There was a problem hiding this comment.
cancel that, I see that you did that.
There was a problem hiding this comment.
The line above your comment is checking 'dim_'. Do you mean something else?
This is a quick fix for #2552
Tuned for thin and tall matrix while keeping performance unchanged on small matrices.
Tests in cudamatrix and nnet* are passed.