Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions src/cudamatrix/cu-device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,25 +62,29 @@ static bool GetCudaContext(int32 num_gpus, std::string *debug_str) {
// Our first attempt to get a device context is: we do cudaFree(0) and see if
// that returns no error code. If it succeeds then we have a device
// context. Apparently this is the canonical way to get a context.
if (cudaFree(0) == 0)
if (cudaFree(0) == 0) {
cudaGetLastError(); // Clear any error status.
return true;
}

// The rest of this code represents how we used to get a device context, but
// now its purpose is mainly a debugging one.
std::ostringstream debug_stream;
debug_stream << "num-gpus=" << num_gpus << ". ";
for (int32 device = 0; device < num_gpus; device++) {
cudaSetDevice(device);
cudaError_t e = cudaDeviceSynchronize(); // << CUDA context gets created here.
cudaError_t e = cudaFree(0); // CUDA context gets created here.
if (e == cudaSuccess) {
*debug_str = debug_stream.str();
if (debug_str)
*debug_str = debug_stream.str();
cudaGetLastError(); // Make sure the error state doesn't get returned in
// the next cudaGetLastError().
return true;
}
debug_stream << "Device " << device << ": " << cudaGetErrorString(e) << ". ";
cudaGetLastError(); // Make sure the error state doesn't get returned in
// the next cudaGetLastError().
}
*debug_str = debug_stream.str();
if (debug_str)
*debug_str = debug_stream.str();
return false;
}

Expand Down Expand Up @@ -164,15 +168,15 @@ void CuDevice::SelectGpuId(std::string use_gpu) {
} else {
int32 num_times = 0;
BaseFloat wait_time = 0.0;
while (! got_context) {
while (!got_context) {
int32 sec_sleep = 5;
if (num_times == 0)
KALDI_WARN << "Will try again indefinitely every " << sec_sleep
<< " seconds to get a GPU.";
num_times++;
wait_time += sec_sleep;
Sleep(sec_sleep);
got_context = GetCudaContext(num_gpus, &debug_str);
got_context = GetCudaContext(num_gpus, NULL);
}

KALDI_WARN << "Waited " << wait_time
Expand Down