diff --git a/src/cudamatrix/cu-rand.cc b/src/cudamatrix/cu-rand.cc index 9d55a3f655e..20439834a98 100644 --- a/src/cudamatrix/cu-rand.cc +++ b/src/cudamatrix/cu-rand.cc @@ -68,7 +68,8 @@ void CuRand::RandUniform(CuMatrixBase *tgt) { // may vary). CuMatrix tmp(tgt->NumRows(), tgt->NumCols(), kUndefined, kStrideEqualNumCols); - CURAND_SAFE_CALL(curandGenerateUniformWrap(gen_, tmp.Data(), tmp.NumRows() * tmp.Stride())); + size_t s = static_cast(tmp.NumRows()) * static_cast(tmp.Stride()); + CURAND_SAFE_CALL(curandGenerateUniformWrap(gen_, tmp.Data(), s)); tgt->CopyFromMat(tmp); CuDevice::Instantiate().AccuProfile(__func__, tim); } else @@ -84,7 +85,8 @@ void CuRand::RandUniform(CuMatrix *tgt) { if (CuDevice::Instantiate().Enabled()) { CuTimer tim; // Here we don't need to use 'tmp' matrix, - CURAND_SAFE_CALL(curandGenerateUniformWrap(gen_, tgt->Data(), tgt->NumRows() * tgt->Stride())); + size_t s = static_cast(tgt->NumRows()) * static_cast(tgt->Stride()); + CURAND_SAFE_CALL(curandGenerateUniformWrap(gen_, tgt->Data(), s)); CuDevice::Instantiate().AccuProfile(__func__, tim); } else #endif