diff --git a/src/cudamatrix/cu-device.cc b/src/cudamatrix/cu-device.cc index b8d6e7edbf5..6abba7fba6b 100644 --- a/src/cudamatrix/cu-device.cc +++ b/src/cudamatrix/cu-device.cc @@ -109,9 +109,10 @@ void CuDevice::Initialize() { cudaSetDevice(device_id_); // Initialize CUBLAS. CUBLAS_SAFE_CALL(cublasCreate(&cublas_handle_)); + CUBLAS_SAFE_CALL(cublasSetStream(cublas_handle_, cudaStreamPerThread)); // Initialize the cuSPARSE library CUSPARSE_SAFE_CALL(cusparseCreate(&cusparse_handle_)); - + CUSPARSE_SAFE_CALL(cusparseSetStream(cusparse_handle_, cudaStreamPerThread)); } } @@ -243,8 +244,10 @@ void CuDevice::FinalizeActiveGpu() { // the main thread. // Initialize CUBLAS. CUBLAS_SAFE_CALL(cublasCreate(&cublas_handle_)); + CUBLAS_SAFE_CALL(cublasSetStream(cublas_handle_, cudaStreamPerThread)); // Initialize the cuSPARSE library CUSPARSE_SAFE_CALL(cusparseCreate(&cusparse_handle_)); + CUSPARSE_SAFE_CALL(cusparseSetStream(cusparse_handle_, cudaStreamPerThread)); // Notify the user which GPU is being userd. char name[128];