diff --git a/src/cudamatrix/cu-vector-speed-test.cc b/src/cudamatrix/cu-vector-speed-test.cc index 258e0840815..82c471db9f9 100644 --- a/src/cudamatrix/cu-vector-speed-test.cc +++ b/src/cudamatrix/cu-vector-speed-test.cc @@ -191,6 +191,31 @@ template void TestCuVectorAddDiagMatMat(int32 dim, } +template void TestCuVectorAddDiagMat2OnVariousShapes( + int32 dim, MatrixTransposeType trans) { + BaseFloat time_in_secs = 0.02; + int32 size = 1024 * 32; + CuVector v(trans == kNoTrans ? size / dim : dim); + v.SetRandn(); + CuMatrix N(size / dim, dim); + N.SetRandn(); + + Timer tim; + int32 iter = 0; + + for (; tim.Elapsed() < time_in_secs; iter++) { + v.AddDiagMat2(1.0, N, trans, 0.0); + } + + BaseFloat fdim = size; + BaseFloat gflops = (fdim * iter) / (tim.Elapsed() * 1.0e+09); + KALDI_LOG << "For CuVector::AddDiagMat2Shapes" << NameOf() + << (trans == kTrans ? "[trans]" : "[no-trans]") << ", for dim = (" + << size / dim << ", " << dim << "), speed was " << gflops << " gigaflops."; +} + + + template void TestCuVectorAddDiagMat2(int32 dim, MatrixTransposeType trans) { BaseFloat time_in_secs = 0.02; CuVector v(dim); @@ -343,7 +368,6 @@ template void TestCuVectorApplyCeilingNoCount(int32 dim) { template void CudaVectorSpeedTest() { std::vector sizes; - sizes.push_back(16); sizes.push_back(32); sizes.push_back(64); sizes.push_back(128); @@ -369,6 +393,10 @@ template void CudaVectorSpeedTest() { TestCuVectorAddDiagMatMat(sizes[s], kTrans, kNoTrans); TestCuVectorAddDiagMatMat(sizes[s], kTrans, kTrans); } + for (int32 s = 0; s < ns; s++) { + TestCuVectorAddDiagMat2OnVariousShapes(sizes[s], kNoTrans); + TestCuVectorAddDiagMat2OnVariousShapes(sizes[s], kTrans); + } for (int32 s = 0; s < ns; s++) { TestCuVectorAddDiagMat2(sizes[s], kNoTrans); TestCuVectorAddDiagMat2(sizes[s], kTrans); @@ -415,4 +443,3 @@ int main() { #endif KALDI_LOG << "Tests succeeded."; } - diff --git a/src/cudamatrix/cu-vector.cc b/src/cudamatrix/cu-vector.cc index f85d20d37f1..0bb652808f2 100644 --- a/src/cudamatrix/cu-vector.cc +++ b/src/cudamatrix/cu-vector.cc @@ -569,8 +569,13 @@ void CuVectorBase::AddDiagMat2(Real alpha, const CuMatrixBase &M, if (CuDevice::Instantiate().Enabled()) { if (dim_ == 0) return; MatrixTransposeType other_trans = (trans == kTrans ? kNoTrans : kTrans); - this->AddDiagMatMat(alpha, M, trans, - M, other_trans, beta); + KALDI_ASSERT(dim_ == (trans == kNoTrans ? M.NumRows() : M.NumCols())); + if (trans == kTrans && M.NumCols() < 512 && M.NumRows() > 8192) { + CuMatrix MT(M, kTrans); + this->AddDiagMatMat(alpha, MT, other_trans, MT, trans, beta); + } else { + this->AddDiagMatMat(alpha, M, trans, M, other_trans, beta); + } } else #endif {