-
Notifications
You must be signed in to change notification settings - Fork 6.8k
[MXNET-1107] Fix CPUPinned unexpected behaviour #12031
Changes from 4 commits
9e23ac9
42dc88f
ca1a0f9
92ecd95
1738b78
4ca556a
5c02c35
af4590d
3cbf4eb
1b7611a
ec51ada
156b368
e7a340f
e316cb8
01b0d6f
0ea50cc
fea651f
b361966
cfdcb29
3914ffd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -68,6 +68,110 @@ inline __device__ bool __is_supported_cuda_architecture() { | |
} | ||
#endif // __CUDACC__ | ||
|
||
/*! | ||
* \brief Check CUDA error. | ||
* \param msg Message to print if an error occured. | ||
*/ | ||
#define CHECK_CUDA_ERROR(msg) \ | ||
{ \ | ||
cudaError_t e = cudaGetLastError(); \ | ||
CHECK_EQ(e, cudaSuccess) << (msg) << " CUDA: " << cudaGetErrorString(e); \ | ||
} | ||
|
||
/*! | ||
* \brief Protected CUDA call. | ||
* \param func Expression to call. | ||
* | ||
* It checks for CUDA errors after invocation of the expression. | ||
*/ | ||
#define CUDA_CALL(func) \ | ||
{ \ | ||
cudaError_t e = (func); \ | ||
CHECK(e == cudaSuccess || e == cudaErrorCudartUnloading) \ | ||
<< "CUDA: " << cudaGetErrorString(e); \ | ||
} | ||
|
||
/*! | ||
* \brief Protected cuBLAS call. | ||
* \param func Expression to call. | ||
* | ||
* It checks for cuBLAS errors after invocation of the expression. | ||
*/ | ||
#define CUBLAS_CALL(func) \ | ||
{ \ | ||
cublasStatus_t e = (func); \ | ||
CHECK_EQ(e, CUBLAS_STATUS_SUCCESS) \ | ||
<< "cuBLAS: " << mxnet::common::cuda::CublasGetErrorString(e); \ | ||
} | ||
|
||
/*! | ||
* \brief Protected cuSolver call. | ||
* \param func Expression to call. | ||
* | ||
* It checks for cuSolver errors after invocation of the expression. | ||
*/ | ||
#define CUSOLVER_CALL(func) \ | ||
{ \ | ||
cusolverStatus_t e = (func); \ | ||
CHECK_EQ(e, CUSOLVER_STATUS_SUCCESS) \ | ||
<< "cuSolver: " << mxnet::common::cuda::CusolverGetErrorString(e); \ | ||
} | ||
|
||
/*! | ||
* \brief Protected cuRAND call. | ||
* \param func Expression to call. | ||
* | ||
* It checks for cuRAND errors after invocation of the expression. | ||
*/ | ||
#define CURAND_CALL(func) \ | ||
{ \ | ||
curandStatus_t e = (func); \ | ||
CHECK_EQ(e, CURAND_STATUS_SUCCESS) \ | ||
<< "cuRAND: " << mxnet::common::cuda::CurandGetErrorString(e); \ | ||
} | ||
|
||
/*! | ||
* \brief Protected NVRTC call. | ||
* \param func Expression to call. | ||
* | ||
* It checks for NVRTC errors after invocation of the expression. | ||
*/ | ||
#define NVRTC_CALL(x) \ | ||
{ \ | ||
nvrtcResult result = x; \ | ||
CHECK_EQ(result, NVRTC_SUCCESS) \ | ||
<< #x " failed with error " \ | ||
<< nvrtcGetErrorString(result); \ | ||
} | ||
|
||
/*! | ||
* \brief Protected CUDA driver call. | ||
* \param func Expression to call. | ||
* | ||
* It checks for CUDA driver errors after invocation of the expression. | ||
*/ | ||
#define CUDA_DRIVER_CALL(func) \ | ||
{ \ | ||
CUresult e = (func); \ | ||
if (e != CUDA_SUCCESS) { \ | ||
char const * err_msg = nullptr; \ | ||
if (cuGetErrorString(e, &err_msg) == CUDA_ERROR_INVALID_VALUE) { \ | ||
LOG(FATAL) << "CUDA Driver: Unknown error " << e; \ | ||
} else { \ | ||
LOG(FATAL) << "CUDA Driver: " << err_msg; \ | ||
} \ | ||
} \ | ||
} | ||
|
||
|
||
#if !defined(_MSC_VER) | ||
#define CUDA_UNROLL _Pragma("unroll") | ||
#define CUDA_NOUNROLL _Pragma("nounroll") | ||
#else | ||
#define CUDA_UNROLL | ||
#define CUDA_NOUNROLL | ||
#endif | ||
|
||
namespace mxnet { | ||
namespace common { | ||
/*! \brief common utils for cuda */ | ||
|
@@ -179,113 +283,38 @@ inline DType __device__ CudaMin(DType a, DType b) { | |
return a < b ? a : b; | ||
} | ||
|
||
} // namespace cuda | ||
} // namespace common | ||
} // namespace mxnet | ||
|
||
/*! | ||
* \brief Check CUDA error. | ||
* \param msg Message to print if an error occured. | ||
*/ | ||
#define CHECK_CUDA_ERROR(msg) \ | ||
{ \ | ||
cudaError_t e = cudaGetLastError(); \ | ||
CHECK_EQ(e, cudaSuccess) << (msg) << " CUDA: " << cudaGetErrorString(e); \ | ||
} | ||
|
||
/*! | ||
* \brief Protected CUDA call. | ||
* \param func Expression to call. | ||
* | ||
* It checks for CUDA errors after invocation of the expression. | ||
*/ | ||
#define CUDA_CALL(func) \ | ||
{ \ | ||
cudaError_t e = (func); \ | ||
CHECK(e == cudaSuccess || e == cudaErrorCudartUnloading) \ | ||
<< "CUDA: " << cudaGetErrorString(e); \ | ||
} | ||
|
||
/*! | ||
* \brief Protected cuBLAS call. | ||
* \param func Expression to call. | ||
* | ||
* It checks for cuBLAS errors after invocation of the expression. | ||
*/ | ||
#define CUBLAS_CALL(func) \ | ||
{ \ | ||
cublasStatus_t e = (func); \ | ||
CHECK_EQ(e, CUBLAS_STATUS_SUCCESS) \ | ||
<< "cuBLAS: " << mxnet::common::cuda::CublasGetErrorString(e); \ | ||
} | ||
|
||
/*! | ||
* \brief Protected cuSolver call. | ||
* \param func Expression to call. | ||
* | ||
* It checks for cuSolver errors after invocation of the expression. | ||
*/ | ||
#define CUSOLVER_CALL(func) \ | ||
{ \ | ||
cusolverStatus_t e = (func); \ | ||
CHECK_EQ(e, CUSOLVER_STATUS_SUCCESS) \ | ||
<< "cuSolver: " << mxnet::common::cuda::CusolverGetErrorString(e); \ | ||
} | ||
|
||
/*! | ||
* \brief Protected cuRAND call. | ||
* \param func Expression to call. | ||
* | ||
* It checks for cuRAND errors after invocation of the expression. | ||
*/ | ||
#define CURAND_CALL(func) \ | ||
{ \ | ||
curandStatus_t e = (func); \ | ||
CHECK_EQ(e, CURAND_STATUS_SUCCESS) \ | ||
<< "cuRAND: " << mxnet::common::cuda::CurandGetErrorString(e); \ | ||
class SetDevice { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Rename to DeviceScope, add a member function SetDevice, and remove all occurrences of |
||
public: | ||
/*! \brief default constructor only restores previous device upon going out of scope */ | ||
SetDevice() { | ||
#if MXNET_USE_CUDA | ||
CUDA_CALL(cudaGetDevice(&restore_device_)); | ||
#else | ||
LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; | ||
#endif | ||
} | ||
|
||
/*! | ||
* \brief Protected NVRTC call. | ||
* \param func Expression to call. | ||
* | ||
* It checks for NVRTC errors after invocation of the expression. | ||
*/ | ||
#define NVRTC_CALL(x) \ | ||
{ \ | ||
nvrtcResult result = x; \ | ||
CHECK_EQ(result, NVRTC_SUCCESS) \ | ||
<< #x " failed with error " \ | ||
<< nvrtcGetErrorString(result); \ | ||
/*! \brief standard constuctor is cudaSetDevice + restore previous device */ | ||
explicit SetDevice(int device) { | ||
#if MXNET_USE_CUDA | ||
CUDA_CALL(cudaGetDevice(&restore_device_)); | ||
CUDA_CALL(cudaSetDevice(device)); | ||
#else | ||
LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; | ||
#endif | ||
} | ||
|
||
/*! | ||
* \brief Protected CUDA driver call. | ||
* \param func Expression to call. | ||
* | ||
* It checks for CUDA driver errors after invocation of the expression. | ||
*/ | ||
#define CUDA_DRIVER_CALL(func) \ | ||
{ \ | ||
CUresult e = (func); \ | ||
if (e != CUDA_SUCCESS) { \ | ||
char const * err_msg = nullptr; \ | ||
if (cuGetErrorString(e, &err_msg) == CUDA_ERROR_INVALID_VALUE) { \ | ||
LOG(FATAL) << "CUDA Driver: Unknown error " << e; \ | ||
} else { \ | ||
LOG(FATAL) << "CUDA Driver: " << err_msg; \ | ||
} \ | ||
} \ | ||
~SetDevice() { | ||
CUDA_CALL(cudaSetDevice(restore_device_)); | ||
} | ||
|
||
private: | ||
int restore_device_; | ||
}; | ||
|
||
#if !defined(_MSC_VER) | ||
#define CUDA_UNROLL _Pragma("unroll") | ||
#define CUDA_NOUNROLL _Pragma("nounroll") | ||
#else | ||
#define CUDA_UNROLL | ||
#define CUDA_NOUNROLL | ||
#endif | ||
} // namespace cuda | ||
} // namespace common | ||
} // namespace mxnet | ||
|
||
/*! | ||
* \brief Determine major version number of the gpu's cuda compute architecture. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -720,6 +720,9 @@ class CommDevice : public Comm { | |
int n = static_cast<int>(gpus.size()); | ||
int enabled = 0; | ||
std::vector<int> p2p(n*n); | ||
|
||
// Restores active device to what it was before EnableP2P | ||
mxnet::common::cuda::SetDevice set_device; | ||
for (int i = 0; i < n; ++i) { | ||
cudaSetDevice(gpus[i]); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As @eric-haibin-lin suggested, replace this with a class method by your RAII object SetDevice (before renaming) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. replace this with a class method by the SetDevice (renaming) object as @eric-haibin-lin suggested. |
||
for (int j = 0; j < n; j++) { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -339,6 +339,7 @@ class CommDeviceTree : public CommDevice { | |
int n = static_cast<int>(gpus.size()); | ||
int enabled = 0; | ||
std::vector<int> p2p(n*n); | ||
mxnet::common::cuda::SetDevice set_device; | ||
for (int i = 0; i < n; ++i) { | ||
cudaSetDevice(gpus[i]); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same here There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same here |
||
for (int j = 0; j < n; j++) { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -428,6 +428,7 @@ class KVStoreNCCL : public KVStoreLocal { | |
mutate_vars.push_back(ptr(dst[i])->var()); | ||
} | ||
Engine::Get()->PushSync([this](RunContext rctx) { | ||
mxnet::common::cuda::SetDevice set_device; | ||
for (auto cur : nccl_data_) { | ||
CUDA_CALL(cudaSetDevice(cur.second.dev_id)); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same here There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same here |
||
CUDA_CALL(cudaStreamSynchronize(cur.second.stream)); | ||
|
@@ -479,6 +480,7 @@ class KVStoreNCCL : public KVStoreLocal { | |
std::lock_guard<std::mutex> l(Storage::Get()->GetMutex(Context::kGPU)); | ||
std::vector<ncclComm_t> comms(devs.size()); | ||
ncclCommInitAll(&(comms[0]), devs.size(), &(device_ids_[0])); | ||
mxnet::common::cuda::SetDevice set_device; | ||
for (size_t i = 0; i < devs.size(); ++i) { | ||
NCCLEntry e; | ||
e.dev_id = device_ids_[i]; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should also fix the documentation here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!