[MXNET-1107] Fix CPUPinned unexpected behaviour (apache#12031)
* Fix CPUPinned unexpected behaviour

* fix lint

* add guards

* Actually, this may affect perf

* trigger ci

* fix lint

* fix documentation

* fix for dist_sync_device

* add guard

* fix bug with memory

* try fix for gluon mp interaction

* blah

* trigger jenkins

* Try fix for gluon multiprocessing bug

Thanks Nvidia!

* edit

* try nvidia fix

* address Haibin and Lin's comments

* get rid of blank line in Makefile
ctcyang authored and piyushghai committed Oct 19, 2018
1 parent 4c1713f commit a232b02
Showing 8 changed files with 168 additions and 109 deletions.
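At a high level, this change replaces bare cudaSetDevice calls, which permanently switch the calling thread's active CUDA device, with a new RAII guard, mxnet::common::cuda::DeviceStore, that records the previously active device and restores it when the guard goes out of scope. That keeps code paths such as CPUPinned allocation and peer-to-peer setup from leaving a worker thread on an unexpected device. The block below is not part of the commit; it is a minimal standalone sketch of the same save/restore idea using only CUDA runtime calls, with an illustrative ScopedDevice name rather than MXNet's DeviceStore, and with error checking omitted for brevity.

#include <cuda_runtime.h>
#include <iostream>

// Minimal stand-in for the RAII save/restore pattern (illustrative, not the MXNet API).
class ScopedDevice {
 public:
  explicit ScopedDevice(int device) {
    cudaGetDevice(&previous_);  // remember the caller's active device
    cudaSetDevice(device);      // switch to the requested device
  }
  ~ScopedDevice() {
    cudaSetDevice(previous_);   // restore the caller's device on scope exit
  }
 private:
  int previous_;
};

int main() {
  cudaSetDevice(0);
  {
    ScopedDevice guard(1);      // work on device 1 inside this scope only
    // ... allocations, kernel launches, cudaMallocHost, etc. ...
  }
  int current = -1;
  cudaGetDevice(&current);
  std::cout << "active device: " << current << std::endl;  // prints 0 again
  return 0;
}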
4 changes: 2 additions & 2 deletions include/mxnet/base.h
@@ -153,10 +153,10 @@ struct Context {
return dev_type;
}
/*!
- * \brief Returns dev_id for kGPU, 0 otherwise
+ * \brief Returns dev_id for kGPU and kCPUPinned, 0 otherwise
*/
inline int real_dev_id() const {
- if (dev_type == kGPU) return dev_id;
+ if (dev_type == kCPUPinned || dev_type == kGPU) return dev_id;
return 0;
}
/*!
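As an aside (not part of the diff), the behavioural change to real_dev_id() can be pictured roughly as follows, assuming the Context::CPUPinned, Context::GPU, and Context::CPU factory helpers declared in this same header; the program is illustrative only and would need to link against libmxnet.

#include <mxnet/base.h>
#include <iostream>

int main() {
  mxnet::Context pinned = mxnet::Context::CPUPinned(2);  // pinned memory associated with GPU 2
  mxnet::Context gpu = mxnet::Context::GPU(1);
  mxnet::Context cpu = mxnet::Context::CPU();

  std::cout << pinned.real_dev_id() << std::endl;  // 2 after this commit (previously 0)
  std::cout << gpu.real_dev_id() << std::endl;     // 1
  std::cout << cpu.real_dev_id() << std::endl;     // 0
  return 0;
}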
222 changes: 122 additions & 100 deletions src/common/cuda_utils.h
@@ -68,6 +68,110 @@ inline __device__ bool __is_supported_cuda_architecture() {
}
#endif // __CUDACC__

/*!
* \brief Check CUDA error.
* \param msg Message to print if an error occured.
*/
#define CHECK_CUDA_ERROR(msg) \
{ \
cudaError_t e = cudaGetLastError(); \
CHECK_EQ(e, cudaSuccess) << (msg) << " CUDA: " << cudaGetErrorString(e); \
}

/*!
* \brief Protected CUDA call.
* \param func Expression to call.
*
* It checks for CUDA errors after invocation of the expression.
*/
#define CUDA_CALL(func) \
{ \
cudaError_t e = (func); \
CHECK(e == cudaSuccess || e == cudaErrorCudartUnloading) \
<< "CUDA: " << cudaGetErrorString(e); \
}

/*!
* \brief Protected cuBLAS call.
* \param func Expression to call.
*
* It checks for cuBLAS errors after invocation of the expression.
*/
#define CUBLAS_CALL(func) \
{ \
cublasStatus_t e = (func); \
CHECK_EQ(e, CUBLAS_STATUS_SUCCESS) \
<< "cuBLAS: " << mxnet::common::cuda::CublasGetErrorString(e); \
}

/*!
* \brief Protected cuSolver call.
* \param func Expression to call.
*
* It checks for cuSolver errors after invocation of the expression.
*/
#define CUSOLVER_CALL(func) \
{ \
cusolverStatus_t e = (func); \
CHECK_EQ(e, CUSOLVER_STATUS_SUCCESS) \
<< "cuSolver: " << mxnet::common::cuda::CusolverGetErrorString(e); \
}

/*!
* \brief Protected cuRAND call.
* \param func Expression to call.
*
* It checks for cuRAND errors after invocation of the expression.
*/
#define CURAND_CALL(func) \
{ \
curandStatus_t e = (func); \
CHECK_EQ(e, CURAND_STATUS_SUCCESS) \
<< "cuRAND: " << mxnet::common::cuda::CurandGetErrorString(e); \
}

/*!
* \brief Protected NVRTC call.
* \param func Expression to call.
*
* It checks for NVRTC errors after invocation of the expression.
*/
#define NVRTC_CALL(x) \
{ \
nvrtcResult result = x; \
CHECK_EQ(result, NVRTC_SUCCESS) \
<< #x " failed with error " \
<< nvrtcGetErrorString(result); \
}

/*!
* \brief Protected CUDA driver call.
* \param func Expression to call.
*
* It checks for CUDA driver errors after invocation of the expression.
*/
#define CUDA_DRIVER_CALL(func) \
{ \
CUresult e = (func); \
if (e != CUDA_SUCCESS) { \
char const * err_msg = nullptr; \
if (cuGetErrorString(e, &err_msg) == CUDA_ERROR_INVALID_VALUE) { \
LOG(FATAL) << "CUDA Driver: Unknown error " << e; \
} else { \
LOG(FATAL) << "CUDA Driver: " << err_msg; \
} \
} \
}


#if !defined(_MSC_VER)
#define CUDA_UNROLL _Pragma("unroll")
#define CUDA_NOUNROLL _Pragma("nounroll")
#else
#define CUDA_UNROLL
#define CUDA_NOUNROLL
#endif

namespace mxnet {
namespace common {
/*! \brief common utils for cuda */
@@ -179,113 +283,31 @@ inline DType __device__ CudaMin(DType a, DType b) {
return a < b ? a : b;
}

- } // namespace cuda
- } // namespace common
- } // namespace mxnet
-
- /*!
- * \brief Check CUDA error.
- * \param msg Message to print if an error occured.
- */
- #define CHECK_CUDA_ERROR(msg) \
- { \
- cudaError_t e = cudaGetLastError(); \
- CHECK_EQ(e, cudaSuccess) << (msg) << " CUDA: " << cudaGetErrorString(e); \
- }
-
- /*!
- * \brief Protected CUDA call.
- * \param func Expression to call.
- *
- * It checks for CUDA errors after invocation of the expression.
- */
- #define CUDA_CALL(func) \
- { \
- cudaError_t e = (func); \
- CHECK(e == cudaSuccess || e == cudaErrorCudartUnloading) \
- << "CUDA: " << cudaGetErrorString(e); \
- }
-
- /*!
- * \brief Protected cuBLAS call.
- * \param func Expression to call.
- *
- * It checks for cuBLAS errors after invocation of the expression.
- */
- #define CUBLAS_CALL(func) \
- { \
- cublasStatus_t e = (func); \
- CHECK_EQ(e, CUBLAS_STATUS_SUCCESS) \
- << "cuBLAS: " << mxnet::common::cuda::CublasGetErrorString(e); \
- }
-
- /*!
- * \brief Protected cuSolver call.
- * \param func Expression to call.
- *
- * It checks for cuSolver errors after invocation of the expression.
- */
- #define CUSOLVER_CALL(func) \
- { \
- cusolverStatus_t e = (func); \
- CHECK_EQ(e, CUSOLVER_STATUS_SUCCESS) \
- << "cuSolver: " << mxnet::common::cuda::CusolverGetErrorString(e); \
- }
-
- /*!
- * \brief Protected cuRAND call.
- * \param func Expression to call.
- *
- * It checks for cuRAND errors after invocation of the expression.
- */
- #define CURAND_CALL(func) \
- { \
- curandStatus_t e = (func); \
- CHECK_EQ(e, CURAND_STATUS_SUCCESS) \
- << "cuRAND: " << mxnet::common::cuda::CurandGetErrorString(e); \
- }
-
- /*!
- * \brief Protected NVRTC call.
- * \param func Expression to call.
- *
- * It checks for NVRTC errors after invocation of the expression.
- */
- #define NVRTC_CALL(x) \
- { \
- nvrtcResult result = x; \
- CHECK_EQ(result, NVRTC_SUCCESS) \
- << #x " failed with error " \
- << nvrtcGetErrorString(result); \
- }
-
- /*!
- * \brief Protected CUDA driver call.
- * \param func Expression to call.
- *
- * It checks for CUDA driver errors after invocation of the expression.
- */
- #define CUDA_DRIVER_CALL(func) \
- { \
- CUresult e = (func); \
- if (e != CUDA_SUCCESS) { \
- char const * err_msg = nullptr; \
- if (cuGetErrorString(e, &err_msg) == CUDA_ERROR_INVALID_VALUE) { \
- LOG(FATAL) << "CUDA Driver: Unknown error " << e; \
- } else { \
- LOG(FATAL) << "CUDA Driver: " << err_msg; \
- } \
- } \
- }
-
-
- #if !defined(_MSC_VER)
- #define CUDA_UNROLL _Pragma("unroll")
- #define CUDA_NOUNROLL _Pragma("nounroll")
- #else
- #define CUDA_UNROLL
- #define CUDA_NOUNROLL
- #endif
+ class DeviceStore {
+  public:
+   /*! \brief default constructor - only optionally restores previous device */
+   explicit DeviceStore(bool restore = true) : restore_(restore) {
+     if (restore_)
+       CUDA_CALL(cudaGetDevice(&restore_device_));
+   }
+
+   ~DeviceStore() {
+     if (restore_)
+       CUDA_CALL(cudaSetDevice(restore_device_));
+   }
+
+   void SetDevice(int device) {
+     CUDA_CALL(cudaSetDevice(device));
+   }
+
+  private:
+   int restore_device_;
+   bool restore_;
+ };
+
+ } // namespace cuda
+ } // namespace common
+ } // namespace mxnet

/*!
* \brief Determine major version number of the gpu's cuda compute architecture.
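For orientation (not part of the diff), the call sites updated in the files below all follow the same pattern: construct a DeviceStore, switch devices through its SetDevice method, and let the destructor restore the caller's device. A hypothetical helper written against the new class might look like this; the function name and include path are illustrative only:

#include <cstddef>
#include "../common/cuda_utils.h"  // assumed relative include path for DeviceStore and CUDA_CALL

// Hypothetical helper: copies host data to a given GPU without leaving the
// calling thread switched to that GPU afterwards.
void CopyToGpu(void* dst, const void* src, size_t bytes, int dev_id) {
  mxnet::common::cuda::DeviceStore device_store;  // saves the currently active device
  device_store.SetDevice(dev_id);                 // switch to the target GPU
  CUDA_CALL(cudaMemcpy(dst, src, bytes, cudaMemcpyHostToDevice));
}                                                 // destructor restores the saved device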
3 changes: 2 additions & 1 deletion src/common/rtc.cc
@@ -77,11 +77,12 @@ CUfunction CudaModule::Chunk::GetFunction(
CHECK_EQ(ctx.dev_mask(), Context::kGPU)
<< "CUDA Runtime compilation only supports Nvidia GPU.";
auto iter = mod_.find(ctx.dev_id);
+ mxnet::common::cuda::DeviceStore device_store;
CUmodule module;
if (iter != mod_.end()) {
module = iter->second;
} else {
- CUDA_CALL(cudaSetDevice(ctx.dev_id));
+ device_store.SetDevice(ctx.dev_id);
CUDA_DRIVER_CALL(cuModuleLoadDataEx(&module, ptx_, 0, 0, 0));
mod_[ctx.dev_id] = module;
}
10 changes: 8 additions & 2 deletions src/engine/stream_manager.h
@@ -65,14 +65,17 @@ template <std::size_t kNumGpus, std::size_t kStreams>
RunContext StreamManager<kNumGpus, kStreams>::GetRunContext(
Context const& ctx) {
RunContext ret;
+ #if MXNET_USE_CUDA
+ mxnet::common::cuda::DeviceStore device_store;
+ #endif
switch (ctx.dev_mask()) {
case cpu::kDevMask:
ret = RunContext{ctx, nullptr};
break;
case gpu::kDevMask: {
#if MXNET_USE_CUDA
std::size_t use_counter;
- CUDA_CALL(cudaSetDevice(ctx.dev_id));
+ device_store.SetDevice(ctx.dev_id);
{
std::lock_guard<std::mutex> lock{mutex_};
auto&& counter = gpu_cnt_.at(ctx.dev_id);
@@ -101,13 +104,16 @@ template <std::size_t kNumGpus, std::size_t kStreams>
RunContext StreamManager<kNumGpus, kStreams>::GetIORunContext(
Context const& ctx) {
RunContext ret;
+ #if MXNET_USE_CUDA
+ mxnet::common::cuda::DeviceStore device_store;
+ #endif
switch (ctx.dev_mask()) {
case cpu::kDevMask:
ret = RunContext{ctx, nullptr};
break;
case gpu::kDevMask: {
#if MXNET_USE_CUDA
- CUDA_CALL(cudaSetDevice(ctx.dev_id));
+ device_store.SetDevice(ctx.dev_id);
{
std::lock_guard<std::mutex> lock{mutex_};
if (gpu_io_streams_.at(ctx.dev_id) == nullptr) {
5 changes: 4 additions & 1 deletion src/kvstore/comm.h
@@ -723,8 +723,11 @@ class CommDevice : public Comm {
int n = static_cast<int>(gpus.size());
int enabled = 0;
std::vector<int> p2p(n*n);

+ // Restores active device to what it was before EnableP2P
+ mxnet::common::cuda::DeviceStore device_store;
for (int i = 0; i < n; ++i) {
- cudaSetDevice(gpus[i]);
+ device_store.SetDevice(gpus[i]);
for (int j = 0; j < n; j++) {
int access;
cudaDeviceCanAccessPeer(&access, gpus[i], gpus[j]);
3 changes: 2 additions & 1 deletion src/kvstore/comm_tree.h
@@ -339,8 +339,9 @@ class CommDeviceTree : public CommDevice {
int n = static_cast<int>(gpus.size());
int enabled = 0;
std::vector<int> p2p(n*n);
+ mxnet::common::cuda::DeviceStore device_store;
for (int i = 0; i < n; ++i) {
- cudaSetDevice(gpus[i]);
+ device_store.SetDevice(gpus[i]);
for (int j = 0; j < n; j++) {
int access;
cudaDeviceCanAccessPeer(&access, gpus[i], gpus[j]);
6 changes: 4 additions & 2 deletions src/kvstore/kvstore_nccl.h
@@ -428,8 +428,9 @@ class KVStoreNCCL : public KVStoreLocal {
mutate_vars.push_back(ptr(dst[i])->var());
}
Engine::Get()->PushSync([this](RunContext rctx) {
+ mxnet::common::cuda::DeviceStore device_store;
for (auto cur : nccl_data_) {
- CUDA_CALL(cudaSetDevice(cur.second.dev_id));
+ device_store.SetDevice(cur.second.dev_id);
CUDA_CALL(cudaStreamSynchronize(cur.second.stream));
}
},
@@ -479,12 +480,13 @@ class KVStoreNCCL : public KVStoreLocal {
std::lock_guard<std::mutex> l(Storage::Get()->GetMutex(Context::kGPU));
std::vector<ncclComm_t> comms(devs.size());
ncclCommInitAll(&(comms[0]), devs.size(), &(device_ids_[0]));
+ mxnet::common::cuda::DeviceStore device_store;
for (size_t i = 0; i < devs.size(); ++i) {
NCCLEntry e;
e.dev_id = device_ids_[i];
e.comm = comms[i];
e.rank = i;
- cudaSetDevice(e.dev_id);
+ device_store.SetDevice(e.dev_id);
cudaStreamCreate(&(e.stream));
nccl_data_[device_ids_[i]] = e;
}
